use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use candle_nn::Conv2dConfig;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{EitherCache, IsqModel, Processor, ProcessorCreator, VisionPromptPrefixer};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::{minicpmo, phi4};

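/// A vision model as supported by this pipeline: a decoder that optionally
/// consumes `pixel_values` alongside the usual token inputs. Model-specific
/// extras (image sizes, aspect-ratio metadata, etc.) travel through the
/// type-erased `model_specific_args`, and `default_model_specific_args`
/// supplies the text-only default for that payload.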
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn has_conv2d(&self) -> bool;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

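/// Loader for a vision model: deserializes the JSON `config` string, builds the
/// model from a [`ShardedVarBuilder`], and exposes the matching [`Processor`]
/// and [`VisionPromptPrefixer`]. The default `get_device_for_tensor` routes
/// each weight to a device by parsing the layer index out of tensor names
/// matching `.layers.<n>.`, falling back to the base device otherwise.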
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self) -> bool;
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self) -> bool;
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

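/// The architecture to load the vision model as. Each variant maps to the
/// serde rename used in configuration files and accepted by [`FromStr`]; a
/// minimal sketch of the parse path:
///
/// ```ignore
/// use std::str::FromStr;
/// assert_eq!(
///     VisionLoaderType::from_str("phi3v").unwrap(),
///     VisionLoaderType::Phi3V
/// );
/// assert!(VisionLoaderType::from_str("not-an-arch").is_err());
/// ```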
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`.")),
        }
    }
}

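/// Counting helper: contributes `$size` extra elements when `$cond` is true
/// (i.e. the layer carries a bias term), and 0 otherwise. For example,
/// `cfg.hidden_size + bias_if!(true, cfg.hidden_size)` counts a biased layer norm.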
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

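/// Counts the parameter elements of a CLIP ViT vision tower: the pre/final
/// layer norms, class/position/patch embeddings, and `num_hidden_layers`
/// encoder layers (two layer norms, biased Q/K/V/O projections, and a biased
/// two-layer MLP). Callers multiply the result by `DType::size_in_bytes()` to
/// estimate the tower's weight footprint in bytes.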
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

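/// Loader for the Phi 3 Vision (`phi3v`) architecture.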
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl VisionPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_index: usize, prompt: &str) -> String {
        format!("<|image_{}|>{prompt}", image_index + 1)
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Phi3::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Phi3Config = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

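/// Loader for the Idefics 2 (`idefics2`) architecture.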
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl VisionPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Idefics2::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Idefics2Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

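/// Loader for the LLaVA-Next (`llava_next`) architecture.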
822pub struct LLaVANextLoader;
828
829pub struct LLaVANextPrefixer;
830
831impl VisionPromptPrefixer for LLaVANextPrefixer {
832 fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
833 format!("<image>{prompt}")
834 }
835}
836
837impl VisionModelLoader for LLaVANextLoader {
838 fn load(
839 &self,
840 config: &str,
841 use_flash_attn: bool,
842 vb: ShardedVarBuilder,
843 normal_loading_metadata: NormalLoadingMetadata,
844 attention_mechanism: AttentionImplementation,
845 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
846 let mut config: LLaVAConfig = serde_json::from_str(config)?;
847 config.use_flash_attn = use_flash_attn;
848 Ok(Box::new(LLaVANext::new(
849 &config,
850 vb,
851 self.is_gptx(),
852 normal_loading_metadata,
853 attention_mechanism,
854 )?))
855 }
856 fn is_gptx(&self) -> bool {
857 false
858 }
859 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
860 let mut config: LLaVAConfig = serde_json::from_str(config)?;
861 config.use_flash_attn = use_flash_attn;
862 Ok(Box::new(config))
863 }
864 fn get_processor(
865 &self,
866 model_config: &str,
867 _processor_config: Option<ProcessorConfig>,
868 _preprocessor_config: PreProcessorConfig,
869 _max_edge: Option<u32>,
870 ) -> Arc<dyn Processor + Send + Sync> {
871 Arc::new(LLaVANextProcessor::new(model_config))
872 }
873 fn supports_paged_attention(&self) -> bool {
874 true
875 }
876 fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
877 Arc::new(LLaVANextPrefixer)
878 }
879}
880
881impl IsqModelLoader for LLaVANextLoader {
882 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
883 Ok(vec![
884 Regex::new(r"lm_head\.(weight|bias)$")?,
885 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
887 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
888 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
889 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
890 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
892 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
893 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
894 ])
895 }
896}
897
898impl DeviceMappedModelLoader for LLaVANextLoader {
899 fn mapped_max_act_size_elems(
900 &self,
901 config: &str,
902 params: &AutoDeviceMapParams,
903 _prompt_chunksize: usize,
904 ) -> Result<usize> {
905 let AutoDeviceMapParams::Vision {
906 max_seq_len,
907 max_batch_size,
908 max_image_shape,
909 max_num_images,
910 } = params
911 else {
912 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
913 };
914
915 let config: LLaVAConfig = serde_json::from_str(config)?;
916
917 #[allow(clippy::cast_possible_truncation)]
918 let img_seq_len =
919 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
920 &config,
921 (max_image_shape.0 as u32, max_image_shape.1 as u32),
922 );
923 let img_seq_len = img_seq_len * max_num_images;
924
925 let max_text_attn = {
926 let cfg = &config.text_config;
927 let max_seq_len = img_seq_len + max_seq_len;
929
930 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
931 };
932
933 Ok(max_text_attn)
934 }
935
936 fn non_mapped_max_act_size_elems(
937 &self,
938 config: &str,
939 params: &AutoDeviceMapParams,
940 ) -> Result<usize> {
941 let AutoDeviceMapParams::Vision {
942 max_seq_len: _,
943 max_batch_size,
944 max_image_shape,
945 max_num_images,
946 } = params
947 else {
948 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
949 };
950
951 let config: LLaVAConfig = serde_json::from_str(config)?;
952
953 #[allow(clippy::cast_possible_truncation)]
954 let img_seq_len =
955 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
956 &config,
957 (max_image_shape.0 as u32, max_image_shape.1 as u32),
958 );
959
960 let max_vision_attn = {
961 (max_batch_size * max_num_images)
962 * config.vision_config.num_attention_heads
963 * img_seq_len
964 * img_seq_len
965 };
966
967 Ok(max_vision_attn)
968 }
969
970 fn non_mapped_size_in_bytes(
971 &self,
972 config: &str,
973 dtype: DType,
974 weight_pack_factor: usize,
975 ) -> Result<usize> {
976 let cfg: LLaVAConfig = serde_json::from_str(config)?;
977 let text_elems = {
978 let cfg = &cfg.text_config;
979 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
980 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
981 let norm = cfg.hidden_size;
982 embed_tokens + lm_head + norm
983 };
984
985 let image_newline = cfg.text_config.hidden_size;
986 let mmproj = {
987 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
988 + cfg.text_config.hidden_size;
989 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
990 + cfg.text_config.hidden_size;
991
992 linear_1 + linear_2
993 };
994 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
995
996 let elems = text_elems + image_newline + mmproj + vision_tower;
997 Ok(elems * dtype.size_in_bytes())
998 }
999
1000 fn layer_sizes_in_bytes(
1001 &self,
1002 config: &str,
1003 dtype: DType,
1004 weight_pack_factor: usize,
1005 ) -> Result<Vec<usize>> {
1006 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1007 let per_layer_elems = {
1008 let cfg = &cfg.text_config;
1009 let input_layernorm = cfg.hidden_size;
1010 let post_attention_layernorm = cfg.hidden_size;
1011
1012 let size_in = cfg.hidden_size;
1013 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1014 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1015 let q_proj = size_in * size_q / weight_pack_factor;
1016 let k_proj = size_in * size_kv / weight_pack_factor;
1017 let v_proj = size_in * size_kv / weight_pack_factor;
1018 let o_proj = size_q * size_in / weight_pack_factor;
1019
1020 let h_size = cfg.hidden_size;
1021 let i_size = cfg.intermediate_size;
1022 let gate_proj = h_size * i_size / weight_pack_factor;
1023 let up_proj = h_size * i_size / weight_pack_factor;
1024 let down_proj = i_size * h_size / weight_pack_factor;
1025
1026 input_layernorm
1027 + post_attention_layernorm
1028 + q_proj
1029 + k_proj
1030 + v_proj
1031 + o_proj
1032 + gate_proj
1033 + up_proj
1034 + down_proj
1035 };
1036 Ok(vec![
1037 per_layer_elems * dtype.size_in_bytes();
1038 cfg.text_config.num_hidden_layers
1039 ])
1040 }
1041
1042 fn num_layers(&self, config: &str) -> Result<usize> {
1043 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1044 Ok(cfg.text_config.num_hidden_layers)
1045 }
1046
1047 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1048 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1049 let cfg = &cfg.text_config;
1050
1051 let cfg = ModelConfigMetadata {
1052 max_seq_len: cfg.max_position_embeddings,
1053 num_layers: cfg.num_hidden_layers,
1054 hidden_size: cfg.hidden_size,
1055 num_kv_heads: cfg.num_key_value_heads,
1056 num_attn_heads: cfg.num_attention_heads,
1057 sliding_window: cfg.sliding_window,
1058 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1059 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1060 };
1061
1062 Ok(Box::new(cfg))
1063 }
1064
1065 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1066 Some(vec![NonMappedSubModel::Vision])
1067 }
1068}
1069
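/// Loader for the LLaVA 1.5 (`llava`) architecture.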
1070pub struct LLaVALoader;
1076
1077pub struct LLaVAPrefixer;
1078
1079impl VisionPromptPrefixer for LLaVAPrefixer {
1080 fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
1081 format!("<image>{prompt}")
1082 }
1083}
1084
1085impl VisionModelLoader for LLaVALoader {
1086 fn load(
1087 &self,
1088 config: &str,
1089 use_flash_attn: bool,
1090 vb: ShardedVarBuilder,
1091 normal_loading_metadata: NormalLoadingMetadata,
1092 attention_mechanism: AttentionImplementation,
1093 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1094 let mut config: LLaVAConfig = serde_json::from_str(config)?;
1095 config.use_flash_attn = use_flash_attn;
1096 Ok(Box::new(LLaVA::new(
1097 &config,
1098 vb,
1099 self.is_gptx(),
1100 normal_loading_metadata,
1101 attention_mechanism,
1102 )?))
1103 }
1104 fn is_gptx(&self) -> bool {
1105 false
1106 }
1107 fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1108 let mut config: LLaVAConfig = serde_json::from_str(config)?;
1109 config.use_flash_attn = use_flash_attn;
1110 Ok(Box::new(config))
1111 }
1112 fn get_processor(
1113 &self,
1114 model_config: &str,
1115 _processor_config: Option<ProcessorConfig>,
1116 _preprocessor_config: PreProcessorConfig,
1117 _max_edge: Option<u32>,
1118 ) -> Arc<dyn Processor + Send + Sync> {
1119 Arc::new(LLaVAProcessor::new(model_config))
1120 }
1121 fn supports_paged_attention(&self) -> bool {
1122 true
1123 }
1124 fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
1125 Arc::new(LLaVAPrefixer)
1126 }
1127}
1128
1129impl IsqModelLoader for LLaVALoader {
1130 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1131 Ok(vec![
1132 Regex::new(r"lm_head\.(weight|bias)$")?,
1133 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1135 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1136 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1137 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1138 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1140 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1141 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1142 ])
1143 }
1144}
1145
1146impl DeviceMappedModelLoader for LLaVALoader {
1147 fn mapped_max_act_size_elems(
1148 &self,
1149 config: &str,
1150 params: &AutoDeviceMapParams,
1151 _prompt_chunksize: usize,
1152 ) -> Result<usize> {
1153 let AutoDeviceMapParams::Vision {
1154 max_seq_len,
1155 max_batch_size,
1156 max_image_shape: _,
1157 max_num_images,
1158 } = params
1159 else {
1160 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1161 };
1162
1163 let config: LLaVAConfig = serde_json::from_str(config)?;
1164
1165 let img_seq_len =
1166 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1167 let img_seq_len = img_seq_len * max_num_images;
1168
1169 let max_text_attn = {
1170 let cfg = &config.text_config;
1171 let max_seq_len = img_seq_len + max_seq_len;
1173
1174 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1175 };
1176
1177 Ok(max_text_attn)
1178 }
1179
1180 fn non_mapped_max_act_size_elems(
1181 &self,
1182 config: &str,
1183 params: &AutoDeviceMapParams,
1184 ) -> Result<usize> {
1185 let AutoDeviceMapParams::Vision {
1186 max_seq_len: _,
1187 max_batch_size,
1188 max_image_shape: _,
1189 max_num_images,
1190 } = params
1191 else {
1192 anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1193 };
1194
1195 let config: LLaVAConfig = serde_json::from_str(config)?;
1196
1197 let img_seq_len =
1198 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1199
1200 let max_vision_attn = {
1201 (max_batch_size * max_num_images)
1202 * config.vision_config.num_attention_heads
1203 * img_seq_len
1204 * img_seq_len
1205 };
1206
1207 Ok(max_vision_attn)
1208 }
1209
1210 fn non_mapped_size_in_bytes(
1211 &self,
1212 config: &str,
1213 dtype: DType,
1214 weight_pack_factor: usize,
1215 ) -> Result<usize> {
1216 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1217 let text_elems = {
1218 let cfg = &cfg.text_config;
1219 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1220 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1221 let norm = cfg.hidden_size;
1222 embed_tokens + lm_head + norm
1223 };
1224
1225 let image_newline = cfg.text_config.hidden_size;
1226 let mmproj = {
1227 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1228 + cfg.text_config.hidden_size;
1229 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1230 + cfg.text_config.hidden_size;
1231
1232 linear_1 + linear_2
1233 };
1234 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1235
1236 let elems = text_elems + image_newline + mmproj + vision_tower;
1237 Ok(elems * dtype.size_in_bytes())
1238 }
1239
1240 fn layer_sizes_in_bytes(
1241 &self,
1242 config: &str,
1243 dtype: DType,
1244 weight_pack_factor: usize,
1245 ) -> Result<Vec<usize>> {
1246 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1247 let per_layer_elems = {
1248 let cfg = &cfg.text_config;
1249 let input_layernorm = cfg.hidden_size;
1250 let post_attention_layernorm = cfg.hidden_size;
1251
1252 let size_in = cfg.hidden_size;
1253 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1254 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1255 let q_proj = size_in * size_q / weight_pack_factor;
1256 let k_proj = size_in * size_kv / weight_pack_factor;
1257 let v_proj = size_in * size_kv / weight_pack_factor;
1258 let o_proj = size_q * size_in / weight_pack_factor;
1259
1260 let h_size = cfg.hidden_size;
1261 let i_size = cfg.intermediate_size;
1262 let gate_proj = h_size * i_size / weight_pack_factor;
1263 let up_proj = h_size * i_size / weight_pack_factor;
1264 let down_proj = i_size * h_size / weight_pack_factor;
1265
1266 input_layernorm
1267 + post_attention_layernorm
1268 + q_proj
1269 + k_proj
1270 + v_proj
1271 + o_proj
1272 + gate_proj
1273 + up_proj
1274 + down_proj
1275 };
1276 Ok(vec![
1277 per_layer_elems * dtype.size_in_bytes();
1278 cfg.text_config.num_hidden_layers
1279 ])
1280 }
1281
1282 fn num_layers(&self, config: &str) -> Result<usize> {
1283 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1284 Ok(cfg.text_config.num_hidden_layers)
1285 }
1286
1287 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1288 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1289 let cfg = &cfg.text_config;
1290
1291 let cfg = ModelConfigMetadata {
1292 max_seq_len: cfg.max_position_embeddings,
1293 num_layers: cfg.num_hidden_layers,
1294 hidden_size: cfg.hidden_size,
1295 num_kv_heads: cfg.num_key_value_heads,
1296 num_attn_heads: cfg.num_attention_heads,
1297 sliding_window: cfg.sliding_window,
1298 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1299 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1300 };
1301
1302 Ok(Box::new(cfg))
1303 }
1304
1305 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1306 Some(vec![NonMappedSubModel::Vision])
1307 }
1308}
1309
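/// Loader for the MLlama (`vllama`) architecture: a Llama-style text model
/// with interleaved cross-attention vision layers.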
pub struct VLlamaLoader;

pub struct VLlamaPrefixer;

impl VisionPromptPrefixer for VLlamaPrefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        format!("<|image|>{prompt}")
    }
}

impl VisionModelLoader for VLlamaLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(MLlamaModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: MLlamaConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MLlamaProcessor::new())
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(VLlamaPrefixer)
    }
}

impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
}

impl DeviceMappedModelLoader for VLlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_cross_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        let max_self_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_self_text_attn.max(max_cross_text_attn))
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let max_vision_attn = {
            let cfg = &config.vision_config;
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &config.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &config.vision_config;

            let conv_cfg = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                * cfg.patch_size
                * cfg.patch_size;

            let class_embedding = cfg.hidden_size;

            let gated_positional_embedding = {
                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
                let embedding = num_patches * cfg.hidden_size;
                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);

                embedding + tile_embedding
            };

            let pre_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
            let post_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);

            let layernorm_pre = cfg.hidden_size;
            let layernorm_post = cfg.hidden_size;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let k_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let v_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let o_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            patch_embedding
                + class_embedding
                + gated_positional_embedding
                + pre_tile_positional_embedding
                + post_tile_positional_embedding
                + layernorm_pre
                + layernorm_post
                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
        };

        let elems = text_elems + vision_elems;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &config.text_config;

        let mut layer_sizes = Vec::new();

        for i in 0..cfg.num_hidden_layers {
            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
                1
            } else {
                weight_pack_factor
            };

            let per_layer_elems = {
                let input_layernorm = cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size;

                let size_in = cfg.hidden_size;
                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + gate_proj
                    + up_proj
                    + down_proj
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        Ok(config.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

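/// Loader for the Qwen2-VL (`qwen2vl`) architecture.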
1684pub struct Qwen2VLLoader;
1690
1691pub struct Qwen2VLPrefixer;
1692
1693impl VisionPromptPrefixer for Qwen2VLPrefixer {
1694 fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
1695 format!(
1696 "{}{}{}{prompt}",
1697 Qwen2VLProcessor::VISION_START,
1698 Qwen2VLProcessor::IMAGE_PAD,
1699 Qwen2VLProcessor::VISION_END
1700 )
1701 }
1702}
1703
1704impl VisionModelLoader for Qwen2VLLoader {
1705 fn load(
1706 &self,
1707 config: &str,
1708 _use_flash_attn: bool,
1709 vb: ShardedVarBuilder,
1710 normal_loading_metadata: NormalLoadingMetadata,
1711 attention_mechanism: AttentionImplementation,
1712 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1713 let config: Qwen2VLConfig = serde_json::from_str(config)?;
1714 Ok(Box::new(Qwen2VLModel::new(
1715 &config,
1716 vb,
1717 self.is_gptx(),
1718 normal_loading_metadata,
1719 attention_mechanism,
1720 )?))
1721 }
1722 fn is_gptx(&self) -> bool {
1723 true
1724 }
1725 fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
1726 let config: Qwen2VLConfig = serde_json::from_str(config)?;
1727 Ok(Box::new(config))
1728 }
1729 fn get_processor(
1730 &self,
1731 _model_config: &str,
1732 _processor_config: Option<ProcessorConfig>,
1733 _preprocessor_config: PreProcessorConfig,
1734 max_edge: Option<u32>,
1735 ) -> Arc<dyn Processor + Send + Sync> {
1736 Arc::new(Qwen2VLProcessor::new(max_edge))
1737 }
1738 fn supports_paged_attention(&self) -> bool {
1739 false
1740 }
1741 fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
1742 Arc::new(Qwen2VLPrefixer)
1743 }
1744}
1745
1746impl IsqModelLoader for Qwen2VLLoader {
1747 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1748 Ok(vec![
1749 Regex::new(r"lm_head\.(weight|bias)$")?,
1750 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1752 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1753 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1754 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1755 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1757 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1758 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1759 ])
1760 }
1761}
1762
impl DeviceMappedModelLoader for Qwen2VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;

            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

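    // `weight_pack_factor` models packed quantized storage: matmul weights are
    // divided by it, while biases and norms are counted at full size. A sketch
    // with illustrative numbers: a 4096 x 4096 projection at a pack factor of
    // 8 contributes 4096 * 4096 / 8 = 2_097_152 elements rather than
    // 16_777_216.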
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

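/// Loader for an Idefics 3 vision model.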
pub struct Idefics3Loader;

pub struct Idefics3Prefixer;

impl VisionPromptPrefixer for Idefics3Prefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        // No inline marker is added here; image tokens are inserted later
        // during input processing.
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Idefics3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Idefics3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Idefics3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics3Processor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Idefics3Prefixer)
    }
}

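// Note that only the text decoder's linear layers are matched for in-situ
// quantization; the vision tower and connector are not covered by these
// patterns.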
impl IsqModelLoader for Idefics3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // Headroom for image splitting: the processor may turn one image
            // into several crops plus a global view (factor assumed here).
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
            let out_dim = cfg.text_config.hidden_size;

            in_dim * out_dim
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

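/// Loader for a MiniCPM-O model.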
pub struct MiniCpmOLoader;

pub struct MiniCpmOPrefixer;

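// MiniCPM-O marks each image with an `(<image>./</image>)` placeholder in the
// prompt text; e.g. prefixing "Describe this." yields
// "(<image>./</image>)Describe this.".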
impl VisionPromptPrefixer for MiniCpmOPrefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        format!("(<image>./</image>){prompt}")
    }
}

impl VisionModelLoader for MiniCpmOLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(MiniCpmOModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: MiniCpmOConfig = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MiniCpmOProcessor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(MiniCpmOPrefixer)
    }
}

impl IsqModelLoader for MiniCpmOLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for MiniCpmOLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // Headroom for image splitting: the processor may turn one image
            // into several crops plus a global view (factor assumed here).
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

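/// Loader for a Phi 4 multimodal model.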
pub struct Phi4MMLoader;

pub struct Phi4MMPrefixer;

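// Phi-4-MM numbers image tags from 1, so `image_index` 0 yields "<|image_1|>"
// prepended to the prompt.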
impl VisionPromptPrefixer for Phi4MMPrefixer {
    fn prefix_image(&self, image_index: usize, prompt: &str) -> String {
        format!("<|image_{}|>{prompt}", image_index + 1)
    }
}

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Phi4MMModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let mut config: Phi4MMConfig = serde_json::from_str(config)?;
        config.use_flash_attn = use_flash_attn;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
}

impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

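// Phi-4-MM fuses Q/K/V into a single `qkv_proj` and gate/up into
// `gate_up_proj` (see the ISQ regexes above), so the per-layer size estimate
// below counts those fused projections.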
impl DeviceMappedModelLoader for Phi4MMLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            // Fused QKV width: q is num_attention_heads * head_dim wide, while
            // k and v are num_key_value_heads * head_dim each.
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

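/// Loader for a Qwen2.5-VL model.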
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

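// Qwen2.5-VL brackets each image with vision start/end tokens around an image
// pad token; the pad is presumably expanded to the actual image embeddings
// later during input processing.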
impl VisionPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        format!(
            "{}{}{}{prompt}",
            Qwen2_5VLProcessor::VISION_START,
            Qwen2_5VLProcessor::IMAGE_PAD,
            Qwen2_5VLProcessor::VISION_END
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self) -> bool {
        false
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len;
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

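/// Loader for a Gemma 3 model (text-only or with vision).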
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl VisionPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        // No inline marker is added here; image tokens are inserted later
        // during input processing.
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3Loader {
    fn load(
        &self,
        config: &str,
        _use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Gemma3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        let config: Gemma3Config = serde_json::from_str(config).unwrap();
        Arc::new(Gemma3Processor::new(
            processor_config.unwrap_or_default(),
            matches!(config, Gemma3Config::WithVision { .. }),
        ))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

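// Gemma 3 configs come in two shapes, so every estimate below matches on the
// variant: text-only models size attention by the prompt chunk, while vision
// models add the projected image tokens to the sequence length first.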
impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * prompt_chunksize
                * prompt_chunksize),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
                    let max_seq_len = img_seq_len + *max_seq_len;
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

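/// Loader for a Mistral 3 vision model.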
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl VisionPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
        // No inline marker is added here; image tokens are inserted later
        // during input processing.
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        use_flash_attn: bool,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut config: Mistral3Config = serde_json::from_str(config)?;
        config.text_config.use_flash_attn = use_flash_attn;
        Ok(Box::new(Mistral3Model::new(
            &config,
            vb,
            self.is_gptx(),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str, _use_flash_attn: bool) -> Result<Box<dyn Debug>> {
        let config: Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self) -> bool {
        true
    }
    fn prefixer(&self) -> Arc<dyn VisionPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        _prompt_chunksize: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

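        // Mirror the processor's resize logic to predict the per-image token
        // count: clamp the longest side (the 1540 ceiling below is assumed to
        // match the default processor config), snap height/width up to the
        // patch grid, then count one token per patch plus one break token per
        // row of patches.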
        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        let max_seq_len = img_seq_len * max_num_images + *max_seq_len;
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                // A 2D patch convolution kernel has patch_size * patch_size
                // elements per (in, out) channel pair.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}