use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{EitherCache, IsqModel, Processor, ProcessorCreator, VisionPromptPrefixer};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

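/// Trait implemented by every vision model in this module: a single `forward`
/// pass over text tokens plus optional pixel values, with model-specific
/// extras passed through as `Box<dyn Any>`.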
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, config: &str) -> Arc<dyn VisionPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

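/// The architecture to load the vision model as. The serialized names (for
/// example `"phi3v"` or `"llava_next"`) are the same strings accepted by the
/// `FromStr` impl and produced by the `Display` impl below.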
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
}

impl VisionLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
        };
        write!(f, "{name}")
    }
}

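/// Minimal view of a Hugging Face `config.json`: only the `architectures`
/// field is read, which is all that automatic loader selection needs.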
#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

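/// Selects and delegates to the concrete [`VisionModelLoader`] based on the
/// `architectures` entry of the model's `config.json`.
///
/// A minimal usage sketch (assuming `config_json` holds the full contents of
/// a supported model's `config.json`):
///
/// ```ignore
/// let loader: Box<dyn VisionModelLoader> = Box::new(AutoVisionLoader);
/// let layers = loader.num_layers(&config_json)?; // delegates to the detected loader
/// ```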
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
        })
    }
}

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

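// Expands to `$size` if `$cond` holds, and to `0` otherwise. Used below to
// count bias elements only for layers that actually carry a bias.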
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

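// Total number of weight elements in a CLIP ViT tower (pre/final layer norms,
// class/position/patch embeddings, and all encoder layers). The memory
// estimates below multiply this element count by the dtype size.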
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

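/// Loader for Phi 3 Vision (`phi3v`) models. Its prefixer inserts 1-indexed
/// image tags, e.g. two images prepend `<|image_1|><|image_2|>` to the prompt.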
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl VisionPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // Image tokens are injected into the text sequence
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

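/// Loader for Idefics 2 (`idefics2`) models. Its prefixer leaves the prompt
/// untouched (no inline image tags are required here).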
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl VisionPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // The chat template handles image token placement, so the prompt is unchanged.
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // Image tokens are injected into the text sequence
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // Conservative headroom factor applied over the raw image count
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

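/// Loader for LLaVA-Next (`llava_next`) models; the prefixer prepends one
/// `<image>` tag per image.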
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl VisionPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // Image tokens are injected into the text sequence
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

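/// Loader for LLaVA 1.5 (`llava`) models; the prefixer prepends one
/// `<image>` tag per image.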
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl VisionPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // Image tokens are injected into the text sequence
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

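/// Loader for MLlama (`vllama`) models such as Llama 3.2 Vision; the prefixer
/// prepends one `<|image|>` tag per image. Note that this loader reports no
/// paged attention support.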
pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl VisionPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(MLlamaModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                // Attention
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                // MLP
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            // Vision transformer self-attention
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // Global transformer self-attention
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // Vision MLP
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

1700impl DeviceMappedModelLoader for VLlamaLoader {
1701 fn mapped_max_act_size_elems(
1702 &self,
1703 config: &str,
1704 params: &AutoDeviceMapParams,
1705 _prompt_chunksize: usize,
1706 ) -> Result<usize> {
1707 let AutoDeviceMapParams::Vision {
1708 max_seq_len,
1709 max_batch_size,
1710 max_image_shape: _,
1711 max_num_images,
1712 } = params
1713 else {
1714 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1715 };
1716
1717 let config: MLlamaConfig = serde_json::from_str(config)?;
1718
1719 let img_seq_len = {
1720 let cfg = &config.vision_config;
1721 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1722 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1723 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1724 };
1725 let img_seq_len = img_seq_len * max_num_images;
1726
1727 let max_cross_text_attn = {
1728 let cfg = &config.text_config;
1729 max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1730 };
1731
1732 let max_self_text_attn = {
1733 let cfg = &config.text_config;
1734 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1735 };
1736
1737 Ok(max_self_text_attn.max(max_cross_text_attn))
1738 }
1739
1740 fn non_mapped_max_act_size_elems(
1741 &self,
1742 config: &str,
1743 params: &AutoDeviceMapParams,
1744 ) -> Result<usize> {
1745 let AutoDeviceMapParams::Vision {
1746 max_seq_len: _,
1747 max_batch_size,
1748 max_image_shape: _,
1749 max_num_images,
1750 } = params
1751 else {
1752 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1753 };
1754
1755 let config: MLlamaConfig = serde_json::from_str(config)?;
1756
1757 let img_seq_len = {
1758 let cfg = &config.vision_config;
1759 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1760 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1761 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1762 };
1763 let max_vision_attn = {
1764 let cfg = &config.vision_config;
1765 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1766 };
1767
1768 Ok(max_vision_attn)
1769 }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &config.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &config.vision_config;

            let conv_cfg = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                * cfg.patch_size
                * cfg.patch_size;

            let class_embedding = cfg.hidden_size;

            let gated_positional_embedding = {
                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
                let embedding = num_patches * cfg.hidden_size;
                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);

                embedding + tile_embedding
            };

            let pre_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
            let post_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);

            let layernorm_pre = cfg.hidden_size;
            let layernorm_post = cfg.hidden_size;

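            // Per-encoder-layer parameters (norms, attention projections, MLP);
            // counted once here, then multiplied over local + global layers below.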
            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let k_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let v_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let o_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            patch_embedding
                + class_embedding
                + gated_positional_embedding
                + pre_tile_positional_embedding
                + post_tile_positional_embedding
                + layernorm_pre
                + layernorm_post
                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
        };

        let elems = text_elems + vision_elems;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &config.text_config;

        let mut layer_sizes = Vec::new();

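        // Cross-attention layers are assumed not to be ISQ-packed, so they
        // use a pack factor of 1 rather than `weight_pack_factor`.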
        for i in 0..cfg.num_hidden_layers {
            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
                1
            } else {
                weight_pack_factor
            };

            let per_layer_elems = {
                let input_layernorm = cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size;

                let size_in = cfg.hidden_size;
                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + gate_proj
                    + up_proj
                    + down_proj
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        Ok(config.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

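/// Loader for Qwen2-VL vision models; implements [`VisionModelLoader`],
/// [`IsqModelLoader`], and [`DeviceMappedModelLoader`].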
pub struct Qwen2VLLoader;

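/// Prefixes prompts with one vision-start/image-pad/vision-end token group
/// per input image.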
pub struct Qwen2VLPrefixer;

impl VisionPromptPrefixer for Qwen2VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2VLProcessor::VISION_START,
                Qwen2VLProcessor::IMAGE_PAD,
                Qwen2VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;

            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

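/// Loader for Idefics 3 vision models; implements [`VisionModelLoader`],
/// [`IsqModelLoader`], and [`DeviceMappedModelLoader`].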
pub struct Idefics3Loader;

pub struct Idefics3Prefixer;

impl VisionPromptPrefixer for Idefics3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics3Processor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics3Prefixer)
    }
}

impl IsqModelLoader for Idefics3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
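            // Assumed headroom factor: each image may be split into several
            // sub-images/crops during preprocessing.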
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
            let out_dim = cfg.text_config.hidden_size;

            in_dim * out_dim
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

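/// Loader for MiniCPM-O vision models; implements [`VisionModelLoader`],
/// [`IsqModelLoader`], and [`DeviceMappedModelLoader`].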
pub struct MiniCpmOLoader;

pub struct MiniCpmOPrefixer;

impl VisionPromptPrefixer for MiniCpmOPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            "(<image>./</image>)".repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for MiniCpmOLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(MiniCpmOModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MiniCpmOProcessor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(MiniCpmOPrefixer)
    }
}

impl IsqModelLoader for MiniCpmOLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"llm\.lm_head\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MiniCpmOLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
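            // Assumed headroom factor: each image may be split into several
            // sub-images/crops during preprocessing.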
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

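/// Loader for Phi 4 multimodal models; implements [`VisionModelLoader`],
/// [`IsqModelLoader`], and [`DeviceMappedModelLoader`].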
pub struct Phi4MMLoader;

pub struct Phi4MMPrefixer;

impl VisionPromptPrefixer for Phi4MMPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(Phi4MMModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
}

impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi4MMLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

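            // The vision-to-text projector's size depends on the configured
            // projection class and whether the HD transform (4x feature concat) is used.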
            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

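/// Loader for Qwen2.5-VL vision models; implements [`VisionModelLoader`],
/// [`IsqModelLoader`], and [`DeviceMappedModelLoader`].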
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl VisionPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

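/// Loader for Gemma 3 models, handling both text-only and vision-enabled
/// configurations; implements [`VisionModelLoader`], [`IsqModelLoader`],
/// and [`DeviceMappedModelLoader`].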
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl VisionPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

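        // Text-only configs have no image activations, so bound attention size
        // by the prompt chunk size rather than an image-augmented sequence length.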
3537 match cfg {
3538 Gemma3Config::Text(text_config) => Ok(max_batch_size
3539 * text_config.num_attention_heads
3540 * prompt_chunksize
3541 * prompt_chunksize),
3542 Gemma3Config::WithVision {
3543 text_config,
3544 vision_config,
3545 ..
3546 } => {
3547 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3548 let img_seq_len = (num_patches + 1) * max_num_images;
3549
3550 let max_text_attn = {
3551 let max_seq_len = img_seq_len + *max_seq_len;
3553 max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3554 };
3555 Ok(max_text_attn)
3556 }
3557 }
3558 }
3559
3560 fn non_mapped_max_act_size_elems(
3561 &self,
3562 config: &str,
3563 params: &AutoDeviceMapParams,
3564 ) -> Result<usize> {
3565 let AutoDeviceMapParams::Vision {
3566 max_seq_len: _,
3567 max_batch_size,
3568 max_image_shape: _,
3569 max_num_images,
3570 } = params
3571 else {
3572 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3573 };
3574
3575 let cfg: Gemma3Config = serde_json::from_str(config)?;
3576
3577 match cfg {
3578 Gemma3Config::WithVision { vision_config, .. } => {
3579 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3580 let img_seq_len = num_patches + 1;
3581
3582 let max_vision_attn = {
3583 (max_batch_size * max_num_images)
3584 * vision_config.num_attention_heads
3585 * img_seq_len
3586 * img_seq_len
3587 };
3588
3589 Ok(max_vision_attn)
3590 }
3591 Gemma3Config::Text(_) => Ok(0),
3592 }
3593 }
3594
3595 fn non_mapped_size_in_bytes(
3596 &self,
3597 config: &str,
3598 dtype: DType,
3599 weight_pack_factor: usize,
3600 ) -> Result<usize> {
3601 let cfg: Gemma3Config = serde_json::from_str(config)?;
3602
3603 let text_elems = {
3604 let cfg = match &cfg {
3605 Gemma3Config::Text(cfg) => cfg,
3606 Gemma3Config::WithVision { text_config, .. } => text_config,
3607 };
3608 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3609 let lm_head = if !cfg.tie_word_embeddings {
3610 cfg.hidden_size * cfg.vocab_size
3611 } else {
3612 0
3613 };
3614 let norm = cfg.hidden_size;
3615 embed_tokens + lm_head + norm
3616 };
3617
3618 let vision_transformer = if let Gemma3Config::WithVision {
3619 vision_config: cfg, ..
3620 } = &cfg
3621 {
3622 let post_layernorm = cfg.hidden_size;
3623
            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
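        // Projection weights may be packed by quantization, hence the division by
        // weight_pack_factor; norms and biases always stay unpacked.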
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl VisionPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(Mistral3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
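            // Pixtral-style token count: shrink so the longest side fits in 1540 px,
            // round up to whole patches, then add one break token per patch row.
            // Worked example (assuming patch_size = 14, a 1000x2000 input): scale by
            // 1540/2000 -> 770x1540 -> 55x110 patches -> (110 + 1) * 55 = 6105 tokens.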
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
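            // Same token-count estimate as mapped_max_act_size_elems above, reused
            // here to bound the vision tower's own self-attention.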
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
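                // Conv2d patch embedding: out_channels * (in_channels / groups)
                // * patch_size * patch_size weights (kernel == stride == patch_size).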
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl VisionPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(Llama4Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
}

impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
}

impl VLlama4Loader {
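    /// How many tokens an image turns into depends on Llama4's dynamic tiling,
    /// which is simpler to measure than to re-derive: run the real image processor
    /// over a blank image of the requested size and count what comes out.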
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn run_dummy_processing(
        &self,
        cfg: &Llama4Config,
        height: usize,
        width: usize,
        max_num_images: usize,
        max_batch_size: usize,
    ) -> Result<(usize, usize)> {
        let cfg = &cfg.vision_config;

        let img_processor =
            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
        let res = img_processor.preprocess(
            vec![image; max_num_images],
            vec![],
            &PreProcessorConfig::default(),
            &Device::Cpu,
            (max_batch_size, max_num_images),
        )?;

        let pixels_batch_size = res.pixel_values.dim(0)?;
        let pixels_max_batch_size = pixels_batch_size * max_batch_size;

        let (image_h, image_w) = (
            res.pixel_values.dim(D::Minus2).unwrap(),
            res.pixel_values.dim(D::Minus1).unwrap(),
        );
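        // Tokens per tile = patches / pixel-shuffle downsampling. Worked example
        // (hypothetical 336x336 tiles, patch_size = 14, downsample_ratio = 4):
        // (336 / 14)^2 / 4 = 576 / 4 = 144 tokens per tile.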
        let num_patches_per_chunk = (image_h / img_processor.patch_size)
            * (image_w / img_processor.patch_size)
            / img_processor.downsample_ratio;

        Ok((
            pixels_max_batch_size,
            num_patches_per_chunk * pixels_max_batch_size,
        ))
    }
}

impl DeviceMappedModelLoader for VLlama4Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (_pixels_batch_size, num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;

        let max_seq_len = max_seq_len + num_text_image_toks;

        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let use_moe = tcfg.moe_layers().contains(&layer_idx);
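            // MoE layers store one MLP weight set per expert, so the block scales
            // with num_local_experts; dense layers use intermediate_size_mlp instead.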
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}