mistralrs_core/vision_models/phi4/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{any::Any, collections::HashMap, sync::Arc};
4
5use candle_core::{Device, Result, Tensor, D};
6use candle_nn::Module;
7use mistralrs_quant::{MatMul, QuantMethod, ReplicatedLayer, ShardedVarBuilder};
8use mm_embedding::Phi4MMImageAudioEmbedding;
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    attention::SdpaParams,
13    device_map::DeviceMapper,
14    layers::{self, Activation, CausalMasker, Phi4MMRotaryEmbedding, RmsNorm, Sdpa},
15    layers_masker::PastKvLenCache,
16    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
17    pipeline::{
18        extract_logits,
19        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
20        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
21    },
22    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
23};
24
25mod config;
26mod image_embedding;
27pub(crate) mod inputs_processor;
28mod mm_embedding;
29
30pub(crate) use config::Phi4MMConfig;
31pub(crate) use image_embedding::PHI4_MM_VISION_CFG;
32
/// Self-attention for one Phi-4MM decoder layer.
///
/// Q, K and V are produced by a single fused `qkv_proj`; grouped-query
/// attention is supported via separate `num_heads` / `num_kv_heads` counts.
struct Attention {
    /// Fused query/key/value projection (built with static LoRA adapters).
    qkv_proj: Arc<dyn QuantMethod>,
    /// Output projection back to the hidden size.
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    /// Rotary position embedding shared among layers on the same device.
    rotary_emb: Arc<Phi4MMRotaryEmbedding>,
    /// Present when the paged-attention backend is in use; `None` for eager.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
43
impl Attention {
    /// Loads the attention weights for one decoder layer from `vb`.
    ///
    /// Both projections are built with static LoRA adapters taken from the
    /// config; no tensor parallelism is applied here.
    fn new(
        rotary_emb: Arc<Phi4MMRotaryEmbedding>,
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        // Fused output width: Q (num_heads) plus K and V (num_kv_heads each).
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        // No TP here.
        let qkv_proj = mistralrs_quant::linear_no_bias_static_lora(
            cfg.hidden_size,
            op_size,
            cfg.loras(),
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias_static_lora(
            num_heads * head_dim,
            cfg.hidden_size,
            cfg.loras(),
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                // Grouped-query attention: each KV head serves this many Q heads.
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                // Standard 1/sqrt(head_dim) attention scaling.
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// Runs self-attention over `xs` (shape `(b_sz, q_len, hidden)`),
    /// updating either the paged KV cache (when `paged_attn`/`metadata` are
    /// available) or the normal `kv_cache`.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // Quantized layers may require activations in a specific dtype;
        // remember the incoming dtype so results can be cast back.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        // Split the fused projection into Q, K and V along the last dimension.
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        // Bring Q/K/V to (b, heads, q_len, head_dim). When q_len == 1 the
        // transpose would be a no-op, so reshape directly into that layout.
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // Apply rotary position embeddings to queries and keys.
        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Eager path: append to the per-layer KV cache, then run SDPA
                // over the full cached sequence.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // NOTE(review): with a mask present the attention output appears to be
        // (b, heads, q_len, head_dim) and is transposed before flattening;
        // without one it is flattened as-is — confirm the backend layouts of
        // Sdpa/PagedAttention match this assumption.
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
203
/// Gated feed-forward block (SwiGLU-style) for one decoder layer.
#[derive(Clone)]
struct Mlp {
    /// Fused gate+up projection: output is `2 * i_size` wide.
    gate_up_proj: Arc<dyn QuantMethod>,
    /// Projection from the intermediate size back to the hidden size.
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    /// Intermediate (per-half) width used to split the fused projection.
    i_size: usize,
}
211
212impl Mlp {
213    fn new(cfg: &Phi4MMConfig, vb: ShardedVarBuilder) -> Result<Self> {
214        let hidden_size = cfg.hidden_size;
215        let i_size = cfg.intermediate_size;
216
217        // No TP here.
218        let gate_up_proj = mistralrs_quant::linear_no_bias_static_lora(
219            hidden_size,
220            2 * i_size,
221            cfg.loras(),
222            vb.pp("gate_up_proj"),
223        )?;
224
225        let down_proj = mistralrs_quant::linear_no_bias_static_lora(
226            i_size,
227            hidden_size,
228            cfg.loras(),
229            vb.pp("down_proj"),
230        )?;
231
232        Ok(Self {
233            gate_up_proj,
234            down_proj,
235            act_fn: cfg.hidden_act,
236            i_size,
237        })
238    }
239
240    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
241        let original_dtype = xs.dtype();
242        let mut xs = xs.clone();
243        if let Some(t) = self.gate_up_proj.quantized_act_type() {
244            xs = xs.to_dtype(t)?;
245        }
246        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
247        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
248        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
249        let up_states = (up_states * gate.apply(&self.act_fn))?;
250        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
251        if self.gate_up_proj.quantized_act_type().is_some() {
252            res = res.to_dtype(original_dtype)?;
253        }
254        Ok(res)
255    }
256}
257
/// One pre-norm transformer block: attention and MLP, each with its own
/// RmsNorm and a residual connection.
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    mlp: Mlp,
    self_attn: Attention,
}
264
265impl DecoderLayer {
266    fn new(
267        rotary_emb: Arc<Phi4MMRotaryEmbedding>,
268        cfg: &Phi4MMConfig,
269        vb: ShardedVarBuilder,
270        mapper: &dyn DeviceMapper,
271        layer_idx: usize,
272        loading_isq: bool,
273        paged_attn: Option<PagedAttention>,
274    ) -> Result<Self> {
275        let self_attn = Attention::new(
276            rotary_emb,
277            cfg,
278            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
279            paged_attn,
280        )?;
281        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
282        let input_layernorm = RmsNorm::new(
283            cfg.hidden_size,
284            cfg.rms_norm_eps,
285            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
286        )?;
287        let post_attention_layernorm = RmsNorm::new(
288            cfg.hidden_size,
289            cfg.rms_norm_eps,
290            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
291        )?;
292
293        Ok(Self {
294            input_layernorm,
295            post_attention_layernorm,
296            mlp,
297            self_attn,
298        })
299    }
300
301    #[allow(clippy::too_many_arguments)]
302    fn forward(
303        &self,
304        xs: &Tensor,
305        attention_mask: Option<&Tensor>,
306        seqlen_offsets: &[usize],
307        position_ids: &[usize],
308        kv_cache: &mut KvCache,
309        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
310        flash_params: &FlashParams,
311    ) -> Result<Tensor> {
312        let residual = xs;
313        let xs = self.input_layernorm.forward(xs)?;
314        let xs = self.self_attn.forward(
315            &xs,
316            attention_mask,
317            seqlen_offsets,
318            position_ids,
319            kv_cache,
320            metadata,
321            flash_params,
322        )?;
323        let xs = (xs + residual)?;
324        let residual = &xs;
325        let xs = self
326            .mlp
327            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
328        residual + xs
329    }
330}
331
/// Phi-4 multimodal language model: token/multimodal embedding, a stack of
/// decoder layers, a final norm, and the LM head.
pub struct Phi4MMModel {
    /// Token embedding table (reused for the LM head when tied).
    embed_tokens: candle_nn::Embedding,
    /// Multimodal embedding that merges image/audio features into the token stream.
    embed_tokens_extend: Phi4MMImageAudioEmbedding,
    layers: Vec<DecoderLayer>,
    /// Final RmsNorm applied before the LM head.
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    /// Device the logits are computed on (the "real" device).
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    /// Maps per-layer activations across devices for multi-device setups.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}
345
impl Phi4MMModel {
    /// Loads the full model from `vb` under the `"model"` prefix.
    ///
    /// One `Phi4MMRotaryEmbedding` is created per distinct device location so
    /// layers mapped to the same device share a RoPE instance. The LM head is
    /// either a separate `lm_head` weight or tied to `embed_tokens`.
    pub fn new(
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        // Build one rotary embedding per device location; layers on the same
        // device will share it via Arc.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Phi4MMRotaryEmbedding::new(vb.dtype(), cfg, device)?),
            );
        }
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            // Paged attention is constructed per layer on that layer's device.
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        // Untied: load a dedicated lm_head weight. Tied: reuse the embedding
        // matrix as the output projection.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        // The multimodal embedding wraps the text embedding so it can fall
        // back to plain token lookup for non-image positions.
        let embed_tokens_extend = Phi4MMImageAudioEmbedding::new(
            cfg,
            embed_tokens.clone(),
            mapper.set_nm_device(vb_m.pp("embed_tokens_extend"), false),
        )?;

        Ok(Self {
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Head counts are divided by the TP world size (KV heads are
                // clamped to at least 1).
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads() / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
            embed_tokens_extend,
        })
    }

    /// Full forward pass: embed (multimodal if image embeds are given), run
    /// every decoder layer, apply the final norm and LM head, and extract
    /// logits at the positions given by `context_lens`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_image_embeds: Option<Tensor>,
        image_attention_mask: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(u32, u32)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // With image embeds present, the multimodal embedder merges them into
        // the token stream; otherwise plain token embedding.
        let mut xs = if let Some(input_image_embeds) = &input_image_embeds {
            self.embed_tokens_extend.forward(
                input_ids,
                input_image_embeds,
                image_attention_mask.as_ref(),
                image_sizes,
            )?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        // Past-KV length comes from seqlen_offsets when paged-attention
        // metadata is present, otherwise from the normal cache itself.
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // Only the first prompt chunk needs a mask under paged attention.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to layer i's device (multi-device mapping).
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    // NOTE(review): to_device failure panics here via unwrap —
                    // consider propagating the error instead.
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        // Bring the hidden states back to the main device for the LM head.
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
528
/// Model-specific arguments passed through the `VisionModel` trait via
/// `Box<dyn Any>`; all fields default to `None` for text-only inputs.
#[derive(Default)]
pub(crate) struct Phi4MMVisionSpecificArgs {
    pub image_sizes: Option<Vec<(u32, u32)>>,
    pub input_image_embeds: Option<Tensor>,
    pub image_attention_mask: Option<Tensor>,
}
535
impl VisionModel for Phi4MMModel {
    /// Trait entry point: unpacks `Phi4MMVisionSpecificArgs` from the
    /// type-erased box and delegates to the inherent `forward`.
    /// `_pixel_values` is unused — image data arrives pre-embedded via
    /// `input_image_embeds`.
    fn forward(
        &self,
        input_ids: &Tensor,
        _pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // Panics if the caller passed the wrong args type; this is a pipeline
        // programming error, not a runtime condition.
        let Phi4MMVisionSpecificArgs {
            image_sizes,
            image_attention_mask,
            input_image_embeds,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Phi4MMVisionSpecificArgs`");
        self.forward(
            input_ids,
            input_image_embeds,
            image_attention_mask,
            seqlen_offsets,
            &position_ids,
            context_lens,
            image_sizes,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
    /// Text-only default: all multimodal fields are `None`.
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(Phi4MMVisionSpecificArgs::default())
    }
}
589
impl IsqModel for Phi4MMModel {
    /// Returns every quantizable linear layer for in-situ quantization:
    /// the LM head (no layer index) plus each layer's attention and MLP
    /// projections tagged with their layer index.
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.mlp.gate_up_proj, Some(i)));
            tensors.push((&mut layer.mlp.down_proj, Some(i)));
        }
        (tensors, &*self.mapper)
    }

    /// Collects the non-quantized ("residual") tensors — embeddings, norms,
    /// and the multimodal embedder — under their safetensors paths, for
    /// serialization alongside the quantized weights.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);
        uvb_m
            .pp("embed_tokens_extend")
            .extend(self.embed_tokens_extend.residual_tensors());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}
629
// No AnyMoE-specific behavior for Phi4MM; the trait's default (no-op) methods apply.
impl AnyMoeBaseModelMixin for Phi4MMModel {}