mistralrs_core/vision_models/phi4/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, collections::HashMap, sync::Arc};

use candle_core::{Device, Result, Tensor, D};
use candle_nn::Module;
use mistralrs_quant::{MatMul, QuantMethod, ReplicatedLayer, ShardedVarBuilder};
use mm_embedding::Phi4MMImageAudioEmbedding;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, CausalMasker, Phi4MMRotaryEmbedding, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

mod config;
mod image_embedding;
pub(crate) mod inputs_processor;
mod mm_embedding;

pub(crate) use config::Phi4MMConfig;
pub(crate) use image_embedding::PHI4_MM_VISION_CFG;

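/// Self-attention block for the Phi-4 multimodal text decoder: a fused QKV
/// projection (with static LoRA adapters), optional paged attention, and the
/// SDPA parameters used on the non-paged path.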
struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Phi4MMRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<Phi4MMRotaryEmbedding>,
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
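        // Output width of the fused QKV projection: `num_heads * head_dim` for Q
        // plus `num_kv_heads * head_dim` each for K and V (grouped-query attention).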
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        // No TP here.
        let qkv_proj = mistralrs_quant::linear_no_bias_static_lora(
            cfg.hidden_size,
            op_size,
            cfg.loras(),
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias_static_lora(
            num_heads * head_dim,
            cfg.hidden_size,
            cfg.loras(),
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

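    // Computes attention for one layer: project to fused QKV, apply RoPE, run either
    // paged attention or KV-cached SDPA, then apply the output projection.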
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

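        // Prompt chunks (q_len > 1) are reshaped then transposed into
        // (batch, heads, seq, head_dim); single-token decode steps can be
        // reshaped directly into that layout.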
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

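        // With paged attention, the block KV cache is managed by the paged backend;
        // otherwise keys/values are appended to this layer's `KvCache` and attention
        // runs through the standard SDPA path.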
        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // Without metadata we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata, assuming we are only processing prompts (not generating text).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

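/// Gated MLP with a fused `gate_up_proj`: the first `i_size` columns of its output
/// act as the gate, the remaining `i_size` as the up-projection; the activated gate
/// scales the up states before `down_proj`.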
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    fn new(cfg: &Phi4MMConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        // No TP here.
        let gate_up_proj = mistralrs_quant::linear_no_bias_static_lora(
            hidden_size,
            2 * i_size,
            cfg.loras(),
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias_static_lora(
            i_size,
            hidden_size,
            cfg.loras(),
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

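/// One decoder layer: pre-norm self-attention and a pre-norm gated MLP, each
/// followed by a residual connection.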
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    mlp: Mlp,
    self_attn: Attention,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<Phi4MMRotaryEmbedding>,
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;

        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            mlp,
            self_attn,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

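/// The Phi-4 multimodal language model. Text tokens are embedded normally; when
/// image embedding inputs are supplied, `Phi4MMImageAudioEmbedding` merges them into
/// the token embedding stream before the decoder layers run.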
pub struct Phi4MMModel {
    embed_tokens: candle_nn::Embedding,
    embed_tokens_extend: Phi4MMImageAudioEmbedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Phi4MMModel {
    pub fn new(
        cfg: &Phi4MMConfig,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;

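        // One rotary embedding is created per device location so that layers mapped
        // to different devices each get a local copy.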
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Phi4MMRotaryEmbedding::new(vb.dtype(), cfg, device)?),
            );
        }
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
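        // When `tie_word_embeddings` is set, the LM head reuses the token embedding
        // matrix instead of loading separate `lm_head` weights.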
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        let embed_tokens_extend = Phi4MMImageAudioEmbedding::new(
            cfg,
            embed_tokens.clone(),
            mapper.set_nm_device(vb_m.pp("embed_tokens_extend"), false),
        )?;

        Ok(Self {
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads() / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
            embed_tokens_extend,
        })
    }

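    // Full forward pass: embed the inputs (merging image features if present), build
    // the sliding-window causal mask, run every decoder layer with device mapping,
    // then normalize and project the selected positions to logits.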
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_image_embeds: Option<Tensor>,
        image_attention_mask: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(u32, u32)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = if let Some(input_image_embeds) = &input_image_embeds {
            self.embed_tokens_extend.forward(
                input_ids,
                input_image_embeds,
                image_attention_mask.as_ref(),
                image_sizes,
            )?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
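        // Build the sliding-window causal mask. With paged-attention metadata the
        // sequence offsets stand in for the past-KV-length cache, and the mask is
        // only kept for the first prompt chunk.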
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

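/// Extra per-request arguments routed through the `VisionModel` trait for Phi-4 MM:
/// the image embedding inputs, their attention mask, and the image sizes.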
#[derive(Default)]
pub(crate) struct Phi4MMVisionSpecificArgs {
    pub image_sizes: Option<Vec<(u32, u32)>>,
    pub input_image_embeds: Option<Tensor>,
    pub image_attention_mask: Option<Tensor>,
}

impl VisionModel for Phi4MMModel {
    fn forward(
        &self,
        input_ids: &Tensor,
        _pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Phi4MMVisionSpecificArgs {
            image_sizes,
            image_attention_mask,
            input_image_embeds,
        } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Phi4MMVisionSpecificArgs`");
        self.forward(
            input_ids,
            input_image_embeds,
            image_attention_mask,
            seqlen_offsets,
            &position_ids,
            context_lens,
            image_sizes,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(Phi4MMVisionSpecificArgs::default())
    }
}

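// In-situ quantization (ISQ) support: the quantizable linears are the per-layer
// QKV/O and MLP projections plus the LM head; everything else (embeddings, norms,
// and the multimodal embedding stack) is reported via `residual_tensors`.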
impl IsqModel for Phi4MMModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.mlp.gate_up_proj, Some(i)));
            tensors.push((&mut layer.mlp.down_proj, Some(i)));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);
        uvb_m
            .pp("embed_tokens_extend")
            .extend(self.embed_tokens_extend.residual_tensors());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

impl AnyMoeBaseModelMixin for Phi4MMModel {}