1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{any::Any, collections::HashMap, sync::Arc};
4
5use candle_core::{Device, Result, Tensor, D};
6use candle_nn::Module;
7use mistralrs_quant::{MatMul, QuantMethod, ReplicatedLayer, ShardedVarBuilder};
8use mm_embedding::Phi4MMImageAudioEmbedding;
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 attention::SdpaParams,
13 device_map::DeviceMapper,
14 layers::{self, Activation, CausalMasker, Phi4MMRotaryEmbedding, RmsNorm, Sdpa},
15 layers_masker::PastKvLenCache,
16 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
17 pipeline::{
18 extract_logits,
19 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
20 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
21 },
22 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
23};
24
25mod config;
26mod image_embedding;
27pub(crate) mod inputs_processor;
28mod mm_embedding;
29
30pub(crate) use config::Phi4MMConfig;
31pub(crate) use image_embedding::PHI4_MM_VISION_CFG;
32
/// Multi-head self-attention for one decoder layer, with grouped-query KV
/// heads and a fused QKV projection.
struct Attention {
    /// Fused Q/K/V projection (output width: `(num_heads + 2 * num_kv_heads) * head_dim`).
    qkv_proj: Arc<dyn QuantMethod>,
    /// Output projection back to the hidden size.
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    /// Shared per-device rotary embedding (see `Phi4MMModel::new`).
    rotary_emb: Arc<Phi4MMRotaryEmbedding>,
    /// Present only when running with the paged-attention backend.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
43
impl Attention {
    /// Load the attention weights for one layer from `vb`.
    ///
    /// Both projections are created with `linear_no_bias_static_lora`, so the
    /// static LoRA adapters from `cfg.loras()` are baked into the layers.
    fn new(
        rotary_emb: Arc<Phi4MMRotaryEmbedding>,
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        // Width of the fused projection: Q block plus one K and one V block.
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        let qkv_proj = mistralrs_quant::linear_no_bias_static_lora(
            cfg.hidden_size,
            op_size,
            cfg.loras(),
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias_static_lora(
            num_heads * head_dim,
            cfg.hidden_size,
            cfg.loras(),
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                // Grouped-query attention: each KV head serves this many Q heads.
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                // Standard 1/sqrt(d) attention scaling.
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// Self-attention forward pass.
    ///
    /// Computes the fused QKV projection, splits it, applies RoPE, and then
    /// dispatches to paged attention (when configured) or to plain SDPA over
    /// the per-layer `KvCache`. Returns `(b_sz, q_len, hidden)` activations.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        // Quantized matmuls may require a specific activation dtype: cast in,
        // run the matmul, and restore `original_dtype` afterwards.
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        // Slice the fused projection into Q | K | V along the last dimension.
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            // Prompt path: reshape to (b, seq, heads, dim) then move heads
            // before the sequence dimension.
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            // Single-token decode: with q_len == 1 the (heads, seq) transpose
            // is a no-op on the data, so reshape directly into target layout.
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // Apply rotary position embeddings to Q and K.
        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                // Normal paged-attention step: read/write the block KV cache.
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No cache metadata supplied: run paged attention with
                    // dummy metadata and no KV cache tensors (nothing is read
                    // or written). A mask is required on this path.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Eager path: append K/V to the rolling cache, then SDPA over
                // the full cached sequence.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // NOTE(review): mask presence is used as a proxy for the prompt path
        // here — with a mask the output is (b, heads, q_len, dim) and needs a
        // transpose before flattening; without one (presumably q_len == 1
        // decode) a plain reshape suffices. Confirm the mask/q_len invariant
        // against the callers.
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
203
/// Gated feed-forward block (SwiGLU-style) with a fused gate/up projection.
#[derive(Clone)]
struct Mlp {
    /// Fused projection producing both the gate and up halves (`2 * i_size` wide).
    gate_up_proj: Arc<dyn QuantMethod>,
    /// Projection from the intermediate size back to the hidden size.
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    /// Intermediate (feed-forward) width; used to split `gate_up_proj` output.
    i_size: usize,
}
211
212impl Mlp {
213 fn new(cfg: &Phi4MMConfig, vb: ShardedVarBuilder) -> Result<Self> {
214 let hidden_size = cfg.hidden_size;
215 let i_size = cfg.intermediate_size;
216
217 let gate_up_proj = mistralrs_quant::linear_no_bias_static_lora(
219 hidden_size,
220 2 * i_size,
221 cfg.loras(),
222 vb.pp("gate_up_proj"),
223 )?;
224
225 let down_proj = mistralrs_quant::linear_no_bias_static_lora(
226 i_size,
227 hidden_size,
228 cfg.loras(),
229 vb.pp("down_proj"),
230 )?;
231
232 Ok(Self {
233 gate_up_proj,
234 down_proj,
235 act_fn: cfg.hidden_act,
236 i_size,
237 })
238 }
239
240 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
241 let original_dtype = xs.dtype();
242 let mut xs = xs.clone();
243 if let Some(t) = self.gate_up_proj.quantized_act_type() {
244 xs = xs.to_dtype(t)?;
245 }
246 let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
247 let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
248 let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
249 let up_states = (up_states * gate.apply(&self.act_fn))?;
250 let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
251 if self.gate_up_proj.quantized_act_type().is_some() {
252 res = res.to_dtype(original_dtype)?;
253 }
254 Ok(res)
255 }
256}
257
/// One pre-norm transformer decoder layer: attention and MLP sub-blocks,
/// each preceded by an RMSNorm and wrapped in a residual connection.
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    mlp: Mlp,
    self_attn: Attention,
}
264
265impl DecoderLayer {
266 fn new(
267 rotary_emb: Arc<Phi4MMRotaryEmbedding>,
268 cfg: &Phi4MMConfig,
269 vb: ShardedVarBuilder,
270 mapper: &dyn DeviceMapper,
271 layer_idx: usize,
272 loading_isq: bool,
273 paged_attn: Option<PagedAttention>,
274 ) -> Result<Self> {
275 let self_attn = Attention::new(
276 rotary_emb,
277 cfg,
278 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
279 paged_attn,
280 )?;
281 let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
282 let input_layernorm = RmsNorm::new(
283 cfg.hidden_size,
284 cfg.rms_norm_eps,
285 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
286 )?;
287 let post_attention_layernorm = RmsNorm::new(
288 cfg.hidden_size,
289 cfg.rms_norm_eps,
290 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
291 )?;
292
293 Ok(Self {
294 input_layernorm,
295 post_attention_layernorm,
296 mlp,
297 self_attn,
298 })
299 }
300
301 #[allow(clippy::too_many_arguments)]
302 fn forward(
303 &self,
304 xs: &Tensor,
305 attention_mask: Option<&Tensor>,
306 seqlen_offsets: &[usize],
307 position_ids: &[usize],
308 kv_cache: &mut KvCache,
309 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
310 flash_params: &FlashParams,
311 ) -> Result<Tensor> {
312 let residual = xs;
313 let xs = self.input_layernorm.forward(xs)?;
314 let xs = self.self_attn.forward(
315 &xs,
316 attention_mask,
317 seqlen_offsets,
318 position_ids,
319 kv_cache,
320 metadata,
321 flash_params,
322 )?;
323 let xs = (xs + residual)?;
324 let residual = &xs;
325 let xs = self
326 .mlp
327 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
328 residual + xs
329 }
330}
331
/// The Phi-4 multimodal decoder model: token embeddings (optionally merged
/// with image/audio embeddings), a stack of decoder layers, a final norm,
/// and the LM head.
pub struct Phi4MMModel {
    embed_tokens: candle_nn::Embedding,
    /// Merges image/audio features into the token embedding stream.
    embed_tokens_extend: Phi4MMImageAudioEmbedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    /// Output projection to vocab logits; may share weights with `embed_tokens`
    /// when the config ties word embeddings.
    lm_head: Arc<dyn QuantMethod>,
    /// Device the final norm / lm_head run on (`real_device` from loading).
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    /// Maps per-layer tensors/VarBuilders across devices.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}
345
impl Phi4MMModel {
    /// Load the full model from `vb`.
    ///
    /// Builds one `Phi4MMRotaryEmbedding` per distinct device location so that
    /// layers mapped to the same device share a single RoPE table, then loads
    /// each decoder layer on its mapped device, the final norm, and the LM
    /// head (tied to the token embeddings when the config requests it).
    pub fn new(
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        // Decoder weights live under the "model" prefix; lm_head is top-level.
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        // First pass: create one rotary embedding per device location.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Phi4MMRotaryEmbedding::new(vb.dtype(), cfg, device)?),
            );
        }
        // Second pass: load each layer, reusing the RoPE for its device.
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            // Separate lm_head weights (note: loaded from `vb`, not `vb_m`).
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            // Tied embeddings: reuse the token embedding matrix as the head.
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        let embed_tokens_extend = Phi4MMImageAudioEmbedding::new(
            cfg,
            embed_tokens.clone(),
            mapper.set_nm_device(vb_m.pp("embed_tokens_extend"), false),
        )?;

        Ok(Self {
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            // Sliding-window-aware per-layer KV cache.
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Head counts are divided by the tensor-parallel world size.
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads() / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
            embed_tokens_extend,
        })
    }

    /// Run the decoder and return logits for the positions selected by
    /// `context_lens`.
    ///
    /// When `input_image_embeds` is present, token embeddings are produced by
    /// the multimodal embedding (which merges image features); otherwise the
    /// plain token embedding is used.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_image_embeds: Option<Tensor>,
        image_attention_mask: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(u32, u32)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = if let Some(input_image_embeds) = &input_image_embeds {
            self.embed_tokens_extend.forward(
                input_ids,
                input_image_embeds,
                image_attention_mask.as_ref(),
                image_sizes,
            )?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        // Past-KV length comes from seqlen offsets under paged attention,
        // otherwise from the normal cache itself.
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // With paged attention, only the first prompt chunk needs a mask;
        // later chunks (and decode) run unmasked.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to layer i's device, and the mask along with them.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
528
/// Model-specific arguments threaded through the `VisionModel` trait as a
/// type-erased `Box<dyn Any>` (see `VisionModel::forward`).
#[derive(Default)]
pub(crate) struct Phi4MMVisionSpecificArgs {
    /// Original (height, width) of each input image, if any.
    pub image_sizes: Option<Vec<(u32, u32)>>,
    /// Precomputed image embeddings to merge into the token stream.
    pub input_image_embeds: Option<Tensor>,
    /// Attention mask over the image embedding positions.
    pub image_attention_mask: Option<Tensor>,
}
535
536impl VisionModel for Phi4MMModel {
537 fn forward(
538 &self,
539 input_ids: &Tensor,
540 _pixel_values: Option<Tensor>,
541 seqlen_offsets: &[usize],
542 context_lens: Vec<(usize, usize)>,
543 position_ids: Vec<usize>,
544 model_specific_args: Box<dyn Any>,
545 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
546 flash_params: &FlashParams,
547 ) -> Result<Tensor> {
548 let Phi4MMVisionSpecificArgs {
549 image_sizes,
550 image_attention_mask,
551 input_image_embeds,
552 } = *model_specific_args
553 .downcast()
554 .expect("Cannot downcast into `Phi4MMVisionSpecificArgs`");
555 self.forward(
556 input_ids,
557 input_image_embeds,
558 image_attention_mask,
559 seqlen_offsets,
560 &position_ids,
561 context_lens,
562 image_sizes,
563 metadata,
564 flash_params,
565 )
566 }
567 fn cache(&self) -> &EitherCache {
568 &self.cache
569 }
570 fn cache_mut(&mut self) -> &mut EitherCache {
571 &mut self.cache
572 }
573 fn device(&self) -> &Device {
574 &self.device
575 }
576 fn max_seq_len(&self) -> usize {
577 self.max_seq_len
578 }
579 fn has_conv2d(&self) -> bool {
580 true
581 }
582 fn config(&self) -> &ModelConfigMetadata {
583 &self.cfg
584 }
585 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
586 Box::new(Phi4MMVisionSpecificArgs::default())
587 }
588}
589
590impl IsqModel for Phi4MMModel {
591 fn get_layers(
592 &mut self,
593 ) -> (
594 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
595 &dyn DeviceMapper,
596 ) {
597 let mut tensors = Vec::new();
598 tensors.push((&mut self.lm_head, None));
599 for (i, layer) in self.layers.iter_mut().enumerate() {
600 tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
601 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
602 tensors.push((&mut layer.mlp.gate_up_proj, Some(i)));
603 tensors.push((&mut layer.mlp.down_proj, Some(i)));
604 }
605 (tensors, &*self.mapper)
606 }
607
608 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
609 let uvb = UnVarBuilder::new();
610
611 let uvb_m = uvb.pp("model");
612 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
613 uvb_m.pp("norm").add(&self.norm);
614 uvb_m
615 .pp("embed_tokens_extend")
616 .extend(self.embed_tokens_extend.residual_tensors());
617
618 for (layer_idx, layer) in self.layers.iter().enumerate() {
619 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
620 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
621 uvb_l
622 .pp("post_attention_layernorm")
623 .add(&layer.post_attention_layernorm);
624 }
625
626 uvb.to_safetensors()
627 }
628}
629
// Empty impl: Phi4-MM opts out of AnyMoE customization and relies entirely on
// the trait's default method implementations.
impl AnyMoeBaseModelMixin for Phi4MMModel {}