1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{any::Any, collections::HashMap, sync::Arc};
4
5use candle_core::{Device, Result, Tensor, D};
6use candle_nn::Module;
7use mistralrs_quant::{MatMul, QuantMethod, ReplicatedLayer, ShardedVarBuilder};
8use mm_embedding::Phi4MMImageAudioEmbedding;
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 attention::SdpaParams,
13 device_map::DeviceMapper,
14 layers::{self, Activation, CausalMasker, Phi4MMRotaryEmbedding, RmsNorm, Sdpa},
15 layers_masker::PastKvLenCache,
16 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
17 pipeline::{
18 extract_logits,
19 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
20 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
21 },
22 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
23};
24
25mod config;
26mod image_embedding;
27pub(crate) mod inputs_processor;
28mod mm_embedding;
29
30pub(crate) use config::Phi4MMConfig;
31pub(crate) use image_embedding::PHI4_MM_VISION_CFG;
32
/// Self-attention for one Phi4-MM decoder layer, using a fused QKV projection
/// and either the PagedAttention kernel or eager SDPA.
struct Attention {
    /// Fused query/key/value projection (LoRA-augmented, possibly quantized).
    qkv_proj: Arc<dyn QuantMethod>,
    /// Output projection back to the model hidden size.
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    /// Number of key/value heads (grouped-query attention when < `num_heads`).
    num_kv_heads: usize,
    head_dim: usize,
    /// Rotary embedding; shared between layers that live on the same device.
    rotary_emb: Arc<Phi4MMRotaryEmbedding>,
    /// `Some` when the pipeline runs with the PagedAttention backend,
    /// `None` for the eager path that appends to a normal `KvCache`.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
43
impl Attention {
    /// Builds the attention module for one decoder layer.
    ///
    /// Both linears are created via `linear_no_bias_static_lora`, which bakes
    /// the static LoRA adapters from `cfg.loras()` into the projection.
    fn new(
        rotary_emb: Arc<Phi4MMRotaryEmbedding>,
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        // Fused QKV output width: Q (num_heads) plus K and V (num_kv_heads each).
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        let qkv_proj = mistralrs_quant::linear_no_bias_static_lora(
            cfg.hidden_size,
            op_size,
            cfg.loras(),
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias_static_lora(
            num_heads * head_dim,
            cfg.hidden_size,
            cfg.loras(),
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                // Grouped-query attention: each KV head serves this many Q heads.
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                // Standard 1/sqrt(d) attention scaling.
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// Runs self-attention over `xs` of shape `(batch, seq, hidden)`.
    ///
    /// Dispatches to PagedAttention when configured (with a dummy-metadata
    /// fallback when no KV-cache handle is supplied), otherwise appends K/V to
    /// the normal cache and runs SDPA. Activations are round-tripped through
    /// the quantized activation dtype when the projections are quantized.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // Cast input to the quantized activation dtype (if any) for the fused
        // QKV matmul, then cast the result back.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        // Slice the fused projection into Q, K, V along the last dim:
        // [0, query_pos) = Q, then K, then V.
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        // Reshape to (batch, heads, seq, head_dim). The q_len == 1 decode case
        // can reshape directly without the transpose.
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // Apply rotary position embeddings to Q and K.
        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                // Normal paged-attention path: K/V caching is handled by the
                // paged KV blocks supplied in the metadata.
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No KV-cache handle was provided: run the kernel
                    // statelessly with dummy metadata. A causal mask must be
                    // present in this mode.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Eager path: extend the per-layer KV cache, then SDPA over
                // the full cached sequence.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // Flatten heads back to (batch, seq, hidden). NOTE(review): this keys
        // off mask presence rather than q_len to decide whether a transpose is
        // needed — assumes masked (prefill) outputs are head-major and decode
        // runs maskless; confirm against the attention backends.
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
203
/// Gated feed-forward block for one decoder layer: a fused gate+up projection
/// followed by an activation-gated down projection.
#[derive(Clone)]
struct Mlp {
    /// Fused gate/up projection; output is `2 * i_size` wide (gate first).
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    /// Activation applied to the gate half (from `cfg.hidden_act`).
    act_fn: Activation,
    /// Intermediate (feed-forward) width; used to split the fused output.
    i_size: usize,
}
211
212impl Mlp {
213 fn new(cfg: &Phi4MMConfig, vb: ShardedVarBuilder) -> Result<Self> {
214 let hidden_size = cfg.hidden_size;
215 let i_size = cfg.intermediate_size;
216
217 let gate_up_proj = mistralrs_quant::linear_no_bias_static_lora(
219 hidden_size,
220 2 * i_size,
221 cfg.loras(),
222 vb.pp("gate_up_proj"),
223 )?;
224
225 let down_proj = mistralrs_quant::linear_no_bias_static_lora(
226 i_size,
227 hidden_size,
228 cfg.loras(),
229 vb.pp("down_proj"),
230 )?;
231
232 Ok(Self {
233 gate_up_proj,
234 down_proj,
235 act_fn: cfg.hidden_act,
236 i_size,
237 })
238 }
239
240 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
241 let original_dtype = xs.dtype();
242 let mut xs = xs.clone();
243 if let Some(t) = self.gate_up_proj.quantized_act_type() {
244 xs = xs.to_dtype(t)?;
245 }
246 let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
247 let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
248 let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
249 let up_states = (up_states * gate.apply(&self.act_fn))?;
250 let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
251 if self.gate_up_proj.quantized_act_type().is_some() {
252 res = res.to_dtype(original_dtype)?;
253 }
254 Ok(res)
255 }
256}
257
/// One pre-norm transformer decoder layer: RMSNorm -> attention -> residual,
/// then RMSNorm -> MLP -> residual.
struct DecoderLayer {
    /// Norm applied before self-attention.
    input_layernorm: RmsNorm,
    /// Norm applied before the MLP.
    post_attention_layernorm: RmsNorm,
    mlp: Mlp,
    self_attn: Attention,
}
264
265impl DecoderLayer {
266 fn new(
267 rotary_emb: Arc<Phi4MMRotaryEmbedding>,
268 cfg: &Phi4MMConfig,
269 vb: ShardedVarBuilder,
270 mapper: &dyn DeviceMapper,
271 layer_idx: usize,
272 loading_isq: bool,
273 paged_attn: Option<PagedAttention>,
274 ) -> Result<Self> {
275 let self_attn = Attention::new(
276 rotary_emb,
277 cfg,
278 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
279 paged_attn,
280 )?;
281 let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
282 let input_layernorm = RmsNorm::new(
283 cfg.hidden_size,
284 cfg.rms_norm_eps,
285 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
286 )?;
287 let post_attention_layernorm = RmsNorm::new(
288 cfg.hidden_size,
289 cfg.rms_norm_eps,
290 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
291 )?;
292
293 Ok(Self {
294 input_layernorm,
295 post_attention_layernorm,
296 mlp,
297 self_attn,
298 })
299 }
300
301 #[allow(clippy::too_many_arguments)]
302 fn forward(
303 &self,
304 xs: &Tensor,
305 attention_mask: Option<&Tensor>,
306 seqlen_offsets: &[usize],
307 position_ids: &[usize],
308 kv_cache: &mut KvCache,
309 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
310 flash_params: &FlashParams,
311 ) -> Result<Tensor> {
312 let residual = xs;
313 let xs = self.input_layernorm.forward(xs)?;
314 let xs = self.self_attn.forward(
315 &xs,
316 attention_mask,
317 seqlen_offsets,
318 position_ids,
319 kv_cache,
320 metadata,
321 flash_params,
322 )?;
323 let xs = (xs + residual)?;
324 let residual = &xs;
325 let xs = self
326 .mlp
327 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
328 residual + xs
329 }
330}
331
/// Phi-4 multimodal language model: a token/image/audio embedding front-end,
/// a stack of decoder layers, a final RMSNorm, and the LM head.
pub struct Phi4MMModel {
    embed_tokens: candle_nn::Embedding,
    /// Multimodal front-end that merges image/audio features into the token
    /// embedding stream when image embeds are supplied.
    embed_tokens_extend: Phi4MMImageAudioEmbedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    /// Output projection to vocab logits; tied to `embed_tokens` when
    /// `tie_word_embeddings` is set.
    lm_head: Arc<dyn QuantMethod>,
    /// Device hidden states are moved to before the final norm / LM head.
    device: Device,
    /// Per-layer normal (sliding-window) KV cache; PagedAttention supplies its
    /// cache through the forward-pass metadata instead.
    cache: EitherCache,
    max_seq_len: usize,
    /// Maps layers and activations across devices for multi-device inference.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}
345
impl Phi4MMModel {
    /// Loads the full model from `vb`, distributing layers across devices via
    /// the mapper in `normal_loading_metadata` and selecting the attention
    /// backend from `attention_mechanism`.
    pub fn new(
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        // One rotary embedding per distinct device location, shared by every
        // layer on that device. NOTE(review): a RoPE is constructed for every
        // layer and duplicates for the same device overwrite each other —
        // wasteful but harmless; could skip already-present locations.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Phi4MMRotaryEmbedding::new(vb.dtype(), cfg, device)?),
            );
        }
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        // LM head: either its own weights, or (when tied) a linear built from
        // the embedding matrix itself.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        let embed_tokens_extend = Phi4MMImageAudioEmbedding::new(
            cfg,
            embed_tokens.clone(),
            mapper.set_nm_device(vb_m.pp("embed_tokens_extend"), false),
        )?;

        Ok(Self {
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Head counts are divided by the tensor-parallel world size;
                // KV heads are floored at 1.
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads() / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
            embed_tokens_extend,
        })
    }

    /// Full forward pass: embed (merging image features when present), run all
    /// decoder layers, final norm, then LM head, returning logits for the
    /// positions selected by `context_lens`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_image_embeds: Option<Tensor>,
        image_attention_mask: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(u32, u32)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // Image embeds present -> multimodal embedding path; otherwise plain
        // token embedding.
        let mut xs = if let Some(input_image_embeds) = &input_image_embeds {
            self.embed_tokens_extend.forward(
                input_ids,
                input_image_embeds,
                image_attention_mask.as_ref(),
                image_sizes,
            )?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        // Past-KV length comes from the paged-attention seqlen offsets when
        // metadata is present, otherwise from the normal cache itself.
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // With paged attention, only the first prompt chunk needs the mask.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to this layer's device per the device map.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    // NOTE(review): device transfer failure panics here rather
                    // than propagating; consider map+transpose with `?`.
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
529
/// Phi4-MM-specific inputs threaded through `VisionModel::forward`'s
/// type-erased `model_specific_args` box. All fields default to `None`
/// (text-only run).
#[derive(Default)]
pub(crate) struct Phi4MMVisionSpecificArgs {
    // Original (height, width) of each input image.
    pub image_sizes: Option<Vec<(u32, u32)>>,
    // Precomputed image embeddings produced by the inputs processor.
    pub input_image_embeds: Option<Tensor>,
    // Mask selecting valid image-embedding positions.
    pub image_attention_mask: Option<Tensor>,
}
536
537impl VisionModel for Phi4MMModel {
538 fn forward(
539 &self,
540 input_ids: &Tensor,
541 _pixel_values: Option<Tensor>,
542 seqlen_offsets: &[usize],
543 context_lens: Vec<(usize, usize)>,
544 position_ids: Vec<usize>,
545 model_specific_args: Box<dyn Any>,
546 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
547 flash_params: &FlashParams,
548 ) -> Result<Tensor> {
549 let Phi4MMVisionSpecificArgs {
550 image_sizes,
551 image_attention_mask,
552 input_image_embeds,
553 } = *model_specific_args
554 .downcast()
555 .expect("Cannot downcast into `Phi4MMVisionSpecificArgs`");
556 self.forward(
557 input_ids,
558 input_image_embeds,
559 image_attention_mask,
560 seqlen_offsets,
561 &position_ids,
562 context_lens,
563 image_sizes,
564 metadata,
565 flash_params,
566 )
567 }
568 fn cache(&self) -> &EitherCache {
569 &self.cache
570 }
571 fn cache_mut(&mut self) -> &mut EitherCache {
572 &mut self.cache
573 }
574 fn device(&self) -> &Device {
575 &self.device
576 }
577 fn max_seq_len(&self) -> usize {
578 self.max_seq_len
579 }
580 fn has_conv2d(&self) -> bool {
581 true
582 }
583 fn config(&self) -> &ModelConfigMetadata {
584 &self.cfg
585 }
586 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
587 Box::new(Phi4MMVisionSpecificArgs::default())
588 }
589}
590
591impl IsqModel for Phi4MMModel {
592 fn get_layers(
593 &mut self,
594 ) -> (
595 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
596 &dyn DeviceMapper,
597 ) {
598 let mut tensors = Vec::new();
599 tensors.push((&mut self.lm_head, None));
600 for (i, layer) in self.layers.iter_mut().enumerate() {
601 tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
602 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
603 tensors.push((&mut layer.mlp.gate_up_proj, Some(i)));
604 tensors.push((&mut layer.mlp.down_proj, Some(i)));
605 }
606 (tensors, &*self.mapper)
607 }
608
609 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
610 let uvb = UnVarBuilder::new();
611
612 let uvb_m = uvb.pp("model");
613 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
614 uvb_m.pp("norm").add(&self.norm);
615 uvb_m
616 .pp("embed_tokens_extend")
617 .extend(self.embed_tokens_extend.residual_tensors());
618
619 for (layer_idx, layer) in self.layers.iter().enumerate() {
620 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
621 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
622 uvb_l
623 .pp("post_attention_layernorm")
624 .add(&layer.post_attention_layernorm);
625 }
626
627 uvb.to_safetensors()
628 }
629}
630
631impl AnyMoeBaseModelMixin for Phi4MMModel {}