use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{EitherCache, IsqModel, Processor, ProcessorCreator, VisionPromptPrefixer};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

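/// Trait implemented by every vision model in this module. `forward` consumes the
/// tokenized prompt plus optional preprocessed `pixel_values`; any extra
/// model-specific inputs (image sizes, pixel attention masks, ...) travel through
/// the type-erased `model_specific_args`.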
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn has_conv2d(&self) -> bool;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

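/// Per-architecture loader: deserializes the model's `config.json` contents,
/// constructs the model from a [`ShardedVarBuilder`], and supplies the matching
/// input processor. A minimal sketch of the intended call flow (the surrounding
/// setup is assumed, not part of this trait):
///
/// ```ignore
/// let loader = Phi3VLoader;
/// let model = loader.load(
///     &config_json, // contents of config.json
///     /* use_flash_attn */ false,
///     vb,           // ShardedVarBuilder over the checkpoint
///     normal_loading_metadata,
///     AttentionImplementation::Eager,
/// )?;
/// let processor = loader.get_processor(&config_json, None, preprocessor_config, None);
/// ```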
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self) -> bool;
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self) -> bool;
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

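/// The architecture to load the vision model as; the serde names below are the
/// accepted identifiers. Parsing goes through [`FromStr`], for example:
///
/// ```ignore
/// let ty: VisionLoaderType = "qwen2vl".parse().unwrap();
/// assert_eq!(ty, VisionLoaderType::Qwen2VL);
/// assert!("not-a-model".parse::<VisionLoaderType>().is_err());
/// ```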
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`.")),
        }
    }
}

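/// Count `$size` elements only when `$cond` holds; used by the size estimators
/// below to account for optional bias tensors:
///
/// ```ignore
/// assert_eq!(bias_if!(true, 16), 16);
/// assert_eq!(bias_if!(false, 16), 0);
/// ```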
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

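/// Number of weight elements (not bytes) in a CLIP ViT tower: patch/position/class
/// embeddings, the surrounding layer norms, and `num_hidden_layers` encoder layers.
/// Multiply by `DType::size_in_bytes()` for a memory estimate.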
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

/// [`VisionLoader`] for a Phi 3 Vision model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

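// Illustrative behavior (derived from the format string below): image tags are
// 1-indexed, so prefix_image(vec![0, 1], "Describe.") == "<|image_1|><|image_2|>Describe."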
impl VisionPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Phi3::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

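// The activation estimates below are dominated by the attention matrix, i.e.
// batch * heads * seq_len^2 elements; for the mapped (text) part, seq_len also
// includes the projected image tokens.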
impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

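    // Worked example of the per-layer accounting below, with hypothetical
    // Phi-3-mini-like sizes (hidden_size = 3072, 32 heads, 32 KV heads,
    // head_dim = 96): op_size = 32*96 + 2*32*96 = 9216, so qkv_proj holds
    // 3072 * 9216 = 28,311,552 elements before any ISQ weight packing.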
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

/// [`VisionLoader`] for an Idefics 2 Vision model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl VisionPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // Chat template does it
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Idefics2::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // do_image_splitting = true
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

/// [`VisionLoader`] for an LLaVANext Vision model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl VisionPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(LLaVANext::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

/// [`VisionLoader`] for an LLaVA Vision model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl VisionPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(LLaVA::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: LLaVAConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

/// [`VisionLoader`] for an MLlama Vision model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl VisionPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(MLlamaModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                // Attention
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                // MLP
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            // Self attention (transformer)
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // Self attention (global transformer)
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
}

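// MLlama alternates self-attention (over the text sequence) and cross-attention
// (over the image sequence), so the mapped activation estimate takes the larger
// of the two attention matrices.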
impl DeviceMappedModelLoader for VLlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_cross_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        let max_self_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_self_text_attn.max(max_cross_text_attn))
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let max_vision_attn = {
            let cfg = &config.vision_config;
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &config.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &config.vision_config;

            let conv_cfg = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                * cfg.patch_size
                * cfg.patch_size;

            let class_embedding = cfg.hidden_size;

            let gated_positional_embedding = {
                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
                let embedding = num_patches * cfg.hidden_size;
                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);

                embedding + tile_embedding
            };

            let pre_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
            let post_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);

            let layernorm_pre = cfg.hidden_size;
            let layernorm_post = cfg.hidden_size;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let k_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let v_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let o_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            patch_embedding
                + class_embedding
                + gated_positional_embedding
                + pre_tile_positional_embedding
                + post_tile_positional_embedding
                + layernorm_pre
                + layernorm_post
                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
        };

        let elems = text_elems + vision_elems;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &config.text_config;

        let mut layer_sizes = Vec::new();

        for i in 0..cfg.num_hidden_layers {
            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
                // No ISQ for cross attention layers
                1
            } else {
                weight_pack_factor
            };

            let per_layer_elems = {
                let input_layernorm = cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size;

                let size_in = cfg.hidden_size;
                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + gate_proj
                    + up_proj
                    + down_proj
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        Ok(config.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

/// [`VisionLoader`] for a Qwen2-VL model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct Qwen2VLLoader;

pub struct Qwen2VLPrefixer;

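// Each image contributes one vision span to the prompt; the literal token
// strings are taken from the Qwen2VLProcessor constants referenced below.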
impl VisionPromptPrefixer for Qwen2VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2VLProcessor::VISION_START,
                Qwen2VLProcessor::IMAGE_PAD,
                Qwen2VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2VLLoader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2VLModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Qwen2VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            // This model injects the vision information directly into the input embeddings
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;

            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

1917 fn layer_sizes_in_bytes(
1918 &self,
1919 config: &str,
1920 dtype: DType,
1921 weight_pack_factor: usize,
1922 ) -> Result<Vec<usize>> {
1923 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1924 let per_layer_elems = {
1925 let input_layernorm = cfg.hidden_size;
1926 let post_attention_layernorm = cfg.hidden_size;
1927
1928 let size_in = cfg.hidden_size;
1929 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1930 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1931 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1932 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1933 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1934 let o_proj = size_q * size_in / weight_pack_factor;
1935
1936 let h_size = cfg.hidden_size;
1937 let i_size = cfg.intermediate_size;
1938 let gate_proj = h_size * i_size / weight_pack_factor;
1939 let up_proj = h_size * i_size / weight_pack_factor;
1940 let down_proj = i_size * h_size / weight_pack_factor;
1941
1942 input_layernorm
1943 + post_attention_layernorm
1944 + q_proj
1945 + k_proj
1946 + v_proj
1947 + o_proj
1948 + gate_proj
1949 + up_proj
1950 + down_proj
1951 };
1952 Ok(vec![
1953 per_layer_elems * dtype.size_in_bytes();
1954 cfg.num_hidden_layers
1955 ])
1956 }
1957
1958 fn num_layers(&self, config: &str) -> Result<usize> {
1959 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1960 Ok(cfg.num_hidden_layers)
1961 }
1962
1963 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1964 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
1965
1966 let cfg = ModelConfigMetadata {
1967 max_seq_len: cfg.max_position_embeddings,
1968 num_layers: cfg.num_hidden_layers,
1969 hidden_size: cfg.hidden_size,
1970 num_kv_heads: cfg.num_key_value_heads,
1971 num_attn_heads: cfg.num_attention_heads,
1972 sliding_window: cfg.sliding_window,
1973 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1974 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1975 };
1976
1977 Ok(Box::new(cfg))
1978 }
1979
1980 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1981 Some(vec![NonMappedSubModel::Vision])
1982 }
1983}
1984
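/// Loader for Idefics 3 vision models: a text backbone plus a ViT-style vision
/// tower, joined by a connector that folds `scale_factor^2` vision features
/// into each text-space embedding.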
pub struct Idefics3Loader;

pub struct Idefics3Prefixer;

impl VisionPromptPrefixer for Idefics3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
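        // The chat template already emits the image tokens, so the prompt passes
        // through unchanged.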
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Idefics3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Idefics3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Idefics3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics3Processor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics3Prefixer)
    }
}

impl IsqModelLoader for Idefics3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
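            // Image tokens are injected into the text sequence, so text attention
            // scales with the combined image + text length.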
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
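            // The factor of 5 presumably budgets for image splitting: up to 4
            // crops plus the global view per image.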
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
            let out_dim = cfg.text_config.hidden_size;

            in_dim * out_dim
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

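/// Loader for MiniCPM-o vision models; the language model weights live under
/// the `llm` prefix.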
pub struct MiniCpmOLoader;

pub struct MiniCpmOPrefixer;

impl VisionPromptPrefixer for MiniCpmOPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
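        // E.g. two images yield "(<image>./</image>)(<image>./</image>)" ahead
        // of the prompt.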
        format!(
            "{}{prompt}",
            "(<image>./</image>)".repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for MiniCpmOLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(MiniCpmOModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MiniCpmOProcessor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(MiniCpmOPrefixer)
    }
}

impl IsqModelLoader for MiniCpmOLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for MiniCpmOLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
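            // Image tokens are injected into the text sequence, so text attention
            // scales with the combined image + text length.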
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
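            // The factor of 5 presumably budgets for image splitting: up to 4
            // crops plus the global view per image.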
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

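/// Loader for Phi-4 multimodal models; the vision configuration comes from the
/// static `PHI4_MM_VISION_CFG`.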
pub struct Phi4MMLoader;

pub struct Phi4MMPrefixer;

impl VisionPromptPrefixer for Phi4MMPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
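        // Phi-4 image placeholders are 1-indexed: images [0, 1] become
        // "<|image_1|><|image_2|>" ahead of the prompt.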
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Phi4MMModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
}

impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Phi4MMLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
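            // Image tokens are injected into the text sequence, so text attention
            // scales with the combined image + text length.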
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

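        // Dynamic HD transform: each image is tiled into one crop per
        // base-resolution tile plus a global view, inflating the effective
        // vision batch size.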
        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

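/// Loader for Qwen2.5-VL models.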
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl VisionPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
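        // Each image contributes VISION_START + IMAGE_PAD + VISION_END ahead of
        // the prompt.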
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
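            // Image tokens are injected into the text sequence, so text attention
            // scales with the combined image + text length.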
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

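/// Loader for Gemma 3 models; handles both text-only and vision-enabled
/// checkpoints.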
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl VisionPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

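        // Text-only checkpoints attend over at most one prompt chunk at a time;
        // vision checkpoints must also budget for the image tokens.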
        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * prompt_chunksize
                * prompt_chunksize),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
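                    // Image tokens are injected into the text sequence, so text
                    // attention scales with the combined image + text length.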
                    let max_seq_len = img_seq_len + *max_seq_len;
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

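/// Loader for Mistral 3 vision models: a Mistral text backbone plus a vision
/// tower.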
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl VisionPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Mistral3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Mistral3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

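        // Mirror the processor's resize: clamp the longest side to 1540px, round
        // up to whole patches, then count one break token per patch row.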
        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

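        // Image tokens are injected into the text sequence, so text attention
        // scales with the combined image + text length.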
        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
3656 };
3657 let ln_pre = cfg.hidden_size;
3658 let vision_layer = {
3659 let attn_norm = cfg.hidden_size;
3660 let ffn_norm = cfg.hidden_size;
3661
3662 let gate = cfg.hidden_size * cfg.intermediate_size;
3663 let up = cfg.hidden_size * cfg.intermediate_size;
3664 let down = cfg.hidden_size * cfg.intermediate_size;
3665
3666 let q = cfg.hidden_size * cfg.hidden_size;
3667 let k = cfg.hidden_size * cfg.hidden_size;
3668 let v = cfg.hidden_size * cfg.hidden_size;
3669 let o = cfg.hidden_size * cfg.hidden_size;
3670
3671 attn_norm + ffn_norm + gate + up + down + q + k + v + o
3672 };
3673
3674 patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
3675 };
3676
3677 let elems = text_elems + vision_elems;
3678
3679 Ok(elems * dtype.size_in_bytes())
3680 }
3681
3682 fn layer_sizes_in_bytes(
3683 &self,
3684 config: &str,
3685 dtype: DType,
3686 weight_pack_factor: usize,
3687 ) -> Result<Vec<usize>> {
3688 let cfg: Mistral3Config = serde_json::from_str(config)?;
3689 let cfg = &cfg.text_config;
3690
3691 let per_layer_elems = {
3692 let input_layernorm = cfg.hidden_size;
3693 let post_attention_layernorm = cfg.hidden_size;
3694
3695 let size_in = cfg.hidden_size;
3696 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3697 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3698 let q_proj = size_in * size_q / weight_pack_factor;
3699 let k_proj = size_in * size_kv / weight_pack_factor;
3700 let v_proj = size_in * size_kv / weight_pack_factor;
3701 let o_proj = size_q * size_in / weight_pack_factor;
3702
3703 let h_size = cfg.hidden_size;
3704 let i_size = cfg.intermediate_size;
3705 let gate_proj = h_size * i_size / weight_pack_factor;
3706 let up_proj = h_size * i_size / weight_pack_factor;
3707 let down_proj = i_size * h_size / weight_pack_factor;
3708
3709 input_layernorm
3710 + post_attention_layernorm
3711 + q_proj
3712 + k_proj
3713 + v_proj
3714 + o_proj
3715 + gate_proj
3716 + up_proj
3717 + down_proj
3718 };
3719 Ok(vec![
3720 per_layer_elems * dtype.size_in_bytes();
3721 cfg.num_hidden_layers
3722 ])
3723 }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

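// Loader for Llama 4 multimodal checkpoints: a MoE text decoder plus a vision
// tower, wired through the same `VisionModelLoader` machinery as the other
// architectures in this module.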
pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl VisionPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}

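// `load` parses the serialized `Llama4Config`, propagates the flash-attention
// flag into the text config, and constructs the full multimodal model.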
impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Llama4Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Llama4Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Llama4Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
}

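// Regexes naming the quantizable (ISQ) linear layers: the LM head, the text
// attention projections, and the feed-forward weights in both their MoE form
// (fused experts, router, shared expert) and their dense form.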
impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
        ])
    }
}

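// Builds dummy RGB images at the requested maximum shape and runs them through
// the real `Llama4ImageProcessor`, so chunk counts and per-image token counts
// come from the actual preprocessing path rather than a hand-derived formula.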
impl VLlama4Loader {
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn run_dummy_processing(
        &self,
        cfg: &Llama4Config,
        height: usize,
        width: usize,
        max_num_images: usize,
        max_batch_size: usize,
    ) -> Result<(usize, usize)> {
        let cfg = &cfg.vision_config;

        let img_processor =
            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
        let res = img_processor.preprocess(
            vec![image; max_num_images],
            vec![],
            &PreProcessorConfig::default(),
            &Device::Cpu,
            (max_batch_size, max_num_images),
        )?;

        let pixels_batch_size = res.pixel_values.dim(0)?;
        let pixels_max_batch_size = pixels_batch_size * max_batch_size;

        let (image_h, image_w) = (
            res.pixel_values.dim(D::Minus2).unwrap(),
            res.pixel_values.dim(D::Minus1).unwrap(),
        );
        let num_patches_per_chunk = (image_h / img_processor.patch_size)
            * (image_w / img_processor.patch_size)
            / img_processor.downsample_ratio;

        Ok((
            pixels_max_batch_size,
            num_patches_per_chunk * pixels_max_batch_size,
        ))
    }
}

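// Memory estimation for automatic device mapping: the activation-size methods
// bound the largest attention buffers, and the `*_size_in_bytes` methods bound
// the weights placed on (or kept off) the mapped devices.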
impl DeviceMappedModelLoader for VLlama4Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (_pixels_batch_size, num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;

        let max_seq_len = max_seq_len + num_text_image_toks;

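        // Worst-case attention score tensor: batch * heads * seq_len^2, where the
        // sequence length includes the projected image tokens.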
        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

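        // The vision tower attends over the patches of each image chunk, so its
        // largest activation scales with the total number of chunks in the batch.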
        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

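        // Vision tower elements: an unfold-based patch embedding, class and
        // positional embeddings, pre/post LayerNorms, the pixel-shuffle projector,
        // and `num_hidden_layers` encoder blocks (weights plus biases).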
        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

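        // Unlike the Mistral 3 loader above, Llama 4 layers are not uniform: layers
        // listed in `moe_layers()` carry the full bank of routed experts, while the
        // rest use a dense MLP sized by `intermediate_size_mlp`.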
        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let use_moe = tcfg.moe_layers().contains(&layer_idx);
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

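    // Metadata consumed when sizing the KV cache; head dims here are derived as
    // hidden_size / num_attention_heads rather than read from a config field.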
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}