mistralrs_core/vision_models/qwen2vl/
text.rs

1use std::{collections::HashMap, sync::Arc};
2
3use candle_core::{DType, Device, Result, Tensor};
4use candle_nn::{Embedding, Module};
5use mistralrs_quant::{
6    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
7};
8
9use crate::{
10    attention::SdpaParams,
11    device_map::DeviceMapper,
12    layers::{self, Activation, F32RmsNorm, Qwen2VLRotaryEmbedding, Sdpa},
13    paged_attention::{AttentionImplementation, ModelConfigMetadata},
14    pipeline::{
15        extract_logits, text_models_inputs_processor::FlashParams, EitherCache, IsqModel, KvCache,
16        NormalCache, NormalLoadingMetadata,
17    },
18    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
19};
20
21use super::config::Config;
22
23struct Mlp {
24    gate_proj: Arc<dyn QuantMethod>,
25    up_proj: Arc<dyn QuantMethod>,
26    down_proj: Arc<dyn QuantMethod>,
27    act_fn: Activation,
28}
29
30impl Mlp {
31    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
32        let hidden_sz = cfg.hidden_size;
33        let intermediate_sz = cfg.intermediate_size;
34        let gate_proj = ColumnParallelLayer::new(
35            hidden_sz,
36            intermediate_sz,
37            &cfg.quantization_config,
38            false,
39            comm,
40            vb.pp("gate_proj"),
41        )?;
42        let up_proj = ColumnParallelLayer::new(
43            hidden_sz,
44            intermediate_sz,
45            &cfg.quantization_config,
46            false,
47            comm,
48            vb.pp("up_proj"),
49        )?;
50        let down_proj = RowParallelLayer::new(
51            intermediate_sz,
52            hidden_sz,
53            &cfg.quantization_config,
54            false,
55            comm,
56            vb.pp("down_proj"),
57        )?;
58        Ok(Self {
59            gate_proj,
60            up_proj,
61            down_proj,
62            act_fn: cfg.hidden_act,
63        })
64    }
65
66    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
67        let original_dtype = xs.dtype();
68        let mut xs = xs.clone();
69        if let Some(t) = self.gate_proj.quantized_act_type() {
70            xs = xs.to_dtype(t)?;
71        }
72        let lhs = self.gate_proj.forward(&xs)?.apply(&self.act_fn)?;
73        let rhs = self.up_proj.forward(&xs)?;
74        self.down_proj
75            .forward(&(lhs * rhs)?)?
76            .to_dtype(original_dtype)
77    }
78}
79
80struct Attention {
81    q_proj: Arc<dyn QuantMethod>,
82    k_proj: Arc<dyn QuantMethod>,
83    v_proj: Arc<dyn QuantMethod>,
84    o_proj: Arc<dyn QuantMethod>,
85    num_heads: usize,
86    num_kv_heads: usize,
87    head_dim: usize,
88    rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
89    sdpa_params: SdpaParams,
90}
91
92impl Attention {
93    fn new(
94        rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
95        cfg: &Config,
96        vb: ShardedVarBuilder,
97        comm: &Arc<mistralrs_quant::Comm>,
98    ) -> Result<Self> {
99        let hidden_sz = cfg.hidden_size;
100        let num_heads = cfg.num_attention_heads;
101        let num_kv_heads = cfg.num_key_value_heads;
102        let head_dim = hidden_sz / num_heads;
103        let q_proj = ColumnParallelLayer::new(
104            hidden_sz,
105            num_heads * head_dim,
106            &cfg.quantization_config,
107            true,
108            comm,
109            vb.pp("q_proj"),
110        )?;
111        let kv_shard = mistralrs_quant::compute_kv_shard(
112            cfg.num_key_value_heads,
113            cfg.hidden_size / cfg.num_attention_heads,
114            comm,
115        );
116        let k_proj = ColumnParallelLayer::new_with_shard(
117            hidden_sz,
118            num_kv_heads * head_dim,
119            &cfg.quantization_config,
120            true,
121            comm,
122            kv_shard,
123            vb.pp("k_proj"),
124        )?;
125        let v_proj = ColumnParallelLayer::new_with_shard(
126            hidden_sz,
127            num_kv_heads * head_dim,
128            &cfg.quantization_config,
129            true,
130            comm,
131            kv_shard,
132            vb.pp("v_proj"),
133        )?;
134        let o_proj = RowParallelLayer::new(
135            num_heads * head_dim,
136            hidden_sz,
137            &cfg.quantization_config,
138            false,
139            comm,
140            vb.pp("o_proj"),
141        )?;
142        Ok(Self {
143            q_proj,
144            k_proj,
145            v_proj,
146            o_proj,
147            num_heads: num_heads / comm.world_size(),
148            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
149            head_dim,
150            rotary_emb,
151            sdpa_params: SdpaParams {
152                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
153                    cfg.num_key_value_heads,
154                    cfg.num_attention_heads,
155                    comm,
156                ),
157                use_flash_attn: false,
158                softcap: None,
159                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
160                sliding_window: None,
161            },
162        })
163    }
164
165    #[allow(clippy::too_many_arguments)]
166    fn forward(
167        &self,
168        xs: &Tensor,
169        attention_mask: Option<&Tensor>,
170        cos_sin: &(Tensor, Tensor),
171        kv_cache: &mut KvCache,
172        flash_params: &FlashParams,
173    ) -> Result<Tensor> {
174        let (b_sz, q_len, _) = xs.dims3()?;
175
176        let original_dtype = xs.dtype();
177        let mut xs = xs.clone();
178        if let Some(t) = self.q_proj.quantized_act_type() {
179            xs = xs.to_dtype(t)?;
180        }
181        let mut q = self.q_proj.forward(&xs)?;
182        let mut k = self.k_proj.forward(&xs)?;
183        let mut v = self.v_proj.forward(&xs)?;
184        if self.q_proj.quantized_act_type().is_some() {
185            q = q.to_dtype(original_dtype)?;
186            k = k.to_dtype(original_dtype)?;
187            v = v.to_dtype(original_dtype)?;
188        }
189
190        let (mut q, mut k, v) = if q_len != 1 {
191            let q = q
192                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
193                .transpose(1, 2)?;
194            let k = k
195                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
196                .transpose(1, 2)?;
197            let v = v
198                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
199                .transpose(1, 2)?;
200            (q, k, v)
201        } else {
202            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
203            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
204            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
205            (q, k, v)
206        };
207
208        self.rotary_emb.forward(cos_sin, &mut q, &mut k)?;
209
210        let mut attn_output = {
211            let (k, v) = kv_cache.append(&k, &v)?;
212
213            Sdpa.run_attention(
214                &q.contiguous()?.to_dtype(DType::F32)?,
215                &k.contiguous()?.to_dtype(DType::F32)?,
216                &v.contiguous()?.to_dtype(DType::F32)?,
217                attention_mask
218                    .map(|mask| mask.to_dtype(DType::F32).unwrap())
219                    .as_ref(),
220                Some(flash_params),
221                &self.sdpa_params,
222            )?
223            .to_dtype(q.dtype())?
224        };
225
226        if let Some(t) = self.q_proj.quantized_act_type() {
227            attn_output = attn_output.to_dtype(t)?;
228        }
229        attn_output = if attention_mask.is_some() {
230            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
231        } else {
232            attn_output.reshape((b_sz, q_len, ()))?
233        };
234        let mut res = self.o_proj.forward(&attn_output)?;
235        if self.q_proj.quantized_act_type().is_some() {
236            res = res.to_dtype(original_dtype)?;
237        }
238        Ok(res)
239    }
240}
241
242pub struct DecoderLayer {
243    self_attn: Attention,
244    mlp: Mlp,
245    input_layernorm: F32RmsNorm,
246    post_attention_layernorm: F32RmsNorm,
247}
248
249impl DecoderLayer {
250    fn new(
251        rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
252        cfg: &Config,
253        vb: ShardedVarBuilder,
254        mapper: &dyn DeviceMapper,
255        layer_idx: usize,
256        loading_isq: bool,
257        comm: &Arc<mistralrs_quant::Comm>,
258    ) -> Result<Self> {
259        let self_attn = Attention::new(
260            rotary_emb,
261            cfg,
262            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
263            comm,
264        )?;
265        let mlp = Mlp::new(
266            cfg,
267            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
268            comm,
269        )?;
270        let input_layernorm = F32RmsNorm::new(
271            cfg.hidden_size,
272            cfg.rms_norm_eps,
273            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
274        )?;
275        let post_attention_layernorm = F32RmsNorm::new(
276            cfg.hidden_size,
277            cfg.rms_norm_eps,
278            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
279        )?;
280        Ok(Self {
281            self_attn,
282            mlp,
283            input_layernorm,
284            post_attention_layernorm,
285        })
286    }
287
288    #[allow(clippy::too_many_arguments)]
289    fn forward(
290        &self,
291        xs: &Tensor,
292        attention_mask: Option<&Tensor>,
293        cos_sin: &(Tensor, Tensor),
294        kv_cache: &mut KvCache,
295        flash_params: &FlashParams,
296    ) -> Result<Tensor> {
297        let residual = xs;
298        let xs = self.input_layernorm.forward(xs)?;
299        let xs = self
300            .self_attn
301            .forward(&xs, attention_mask, cos_sin, kv_cache, flash_params)?;
302        let xs = (xs + residual)?;
303        let residual = &xs;
304        let xs = self
305            .mlp
306            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
307        residual + xs
308    }
309}
310
311pub struct Qwen2VLTextModel {
312    embed_tokens: Embedding,
313    pub(super) norm: F32RmsNorm,
314    layers: Vec<DecoderLayer>,
315    mapper: Box<dyn DeviceMapper + Send + Sync>,
316    lm_head: Arc<dyn QuantMethod>,
317    pub(super) cache: EitherCache,
318    pub(super) cfg: ModelConfigMetadata,
319    pub(super) device: Device,
320    pub(super) dtype: DType,
321    pub(super) max_seq_len: usize,
322}
323
324impl Qwen2VLTextModel {
325    pub fn new(
326        cfg: &Config,
327        vb: ShardedVarBuilder,
328        _is_gptx: bool,
329        normal_loading_metadata: NormalLoadingMetadata,
330        attention_mechanism: AttentionImplementation,
331    ) -> Result<Self> {
332        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
333            candle_core::bail!("Expected eager attention implementation");
334        }
335        let mapper = normal_loading_metadata.mapper;
336        let vb_m = vb.pp("model");
337
338        let embed_tokens = layers::embedding(
339            cfg.vocab_size,
340            cfg.hidden_size,
341            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
342        )?;
343        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
344        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
345
346        let mut ropes = HashMap::new();
347        for layer_idx in 0..cfg.num_hidden_layers {
348            let device = mapper
349                .device_for(layer_idx, false)
350                .unwrap_or(&normal_loading_metadata.real_device);
351            ropes.insert(
352                device.location(),
353                Arc::new(Qwen2VLRotaryEmbedding::new(
354                    cfg.rope_theta as f32,
355                    head_dim,
356                    device,
357                    cfg.rope_scaling.mrope_section.clone(),
358                )?),
359            );
360        }
361
362        let vb_l = vb_m.pp("layers");
363        for layer_idx in NiceProgressBar::<_, 'b'>(
364            0..cfg.num_hidden_layers,
365            "Loading repeating layers",
366            &normal_loading_metadata.multi_progress,
367        ) {
368            let device = mapper
369                .device_for(layer_idx, false)
370                .unwrap_or(&normal_loading_metadata.real_device);
371            let rotary_emb = ropes
372                .get(&device.location())
373                .expect("No RoPE for device location!")
374                .clone();
375            let comm = mapper.get_comm_for(layer_idx)?;
376            let layer = DecoderLayer::new(
377                rotary_emb.clone(),
378                cfg,
379                vb_l.pp(layer_idx),
380                &*mapper,
381                layer_idx,
382                normal_loading_metadata.loading_isq,
383                &comm,
384            )?;
385            layers.push(layer)
386        }
387        let norm = F32RmsNorm::new(
388            cfg.hidden_size,
389            cfg.rms_norm_eps,
390            mapper.set_nm_device(vb_m.pp("norm"), false),
391        )?;
392        let lm_head = if !cfg.tie_word_embeddings {
393            ReplicatedLayer::new(
394                cfg.hidden_size,
395                cfg.vocab_size,
396                &None,
397                false,
398                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
399            )?
400        } else {
401            ReplicatedLayer::from_linear(candle_nn::Linear::new(
402                mapper.cast_nm_device(
403                    embed_tokens.embeddings(),
404                    normal_loading_metadata.loading_isq,
405                )?,
406                None,
407            ))?
408        };
409        Ok(Self {
410            embed_tokens,
411            norm,
412            layers,
413            lm_head,
414            cache: EitherCache::Normal(NormalCache::new(
415                cfg.num_hidden_layers,
416                cfg.max_position_embeddings,
417            )),
418            max_seq_len: cfg.max_position_embeddings,
419            cfg: ModelConfigMetadata {
420                max_seq_len: cfg.max_position_embeddings,
421                num_layers: cfg.num_hidden_layers,
422                hidden_size: cfg.hidden_size,
423                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
424                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
425                    .max(1),
426                sliding_window: cfg.sliding_window,
427                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
428                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
429            },
430            device: normal_loading_metadata.real_device.clone(),
431            dtype: vb.dtype(),
432            mapper,
433        })
434    }
435
436    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
437        self.embed_tokens.forward(input_ids)
438    }
439
440    pub fn forward_embeds(
441        &self,
442        mut xs: Tensor,
443        attention_mask: Option<&Tensor>,
444        position_ids: &Tensor,
445        context_lens: Vec<(usize, usize)>,
446        flash_params: &FlashParams,
447    ) -> Result<Tensor> {
448        let cache = &mut self.cache.normal().0;
449        let cos_sin = self.layers[0]
450            .self_attn
451            .rotary_emb
452            .compute_cos_sin(position_ids, xs.dtype())?;
453
454        for (i, layer) in self.layers.iter().enumerate() {
455            xs = self.mapper.map(xs, i)?;
456            xs = layer.forward(
457                &xs,
458                attention_mask
459                    .as_ref()
460                    .map(|m| m.to_device(xs.device()).unwrap())
461                    .as_ref(),
462                &cos_sin,
463                &mut cache[i],
464                flash_params,
465            )?
466        }
467        let xs = xs.to_device(&self.device)?;
468        let mut xs = xs.apply(&self.norm)?;
469        if let Some(t) = self.lm_head.quantized_act_type() {
470            xs = xs.to_dtype(t)?;
471        }
472        extract_logits(&self.lm_head.forward(&xs)?, context_lens)
473    }
474}
475
476impl IsqModel for Qwen2VLTextModel {
477    fn get_layers(
478        &mut self,
479    ) -> (
480        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
481        &dyn DeviceMapper,
482    ) {
483        let mut tensors = Vec::new();
484        tensors.push((&mut self.lm_head, None));
485        for (i, layer) in self.layers.iter_mut().enumerate() {
486            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
487            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
488            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
489            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
490            tensors.push((&mut layer.mlp.gate_proj, Some(i)));
491            tensors.push((&mut layer.mlp.up_proj, Some(i)));
492            tensors.push((&mut layer.mlp.down_proj, Some(i)));
493        }
494        (tensors, &*self.mapper)
495    }
496
497    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
498        let uvb = UnVarBuilder::new();
499
500        let uvb_m = uvb.pp("model");
501        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
502        uvb_m.pp("norm").add(&self.norm);
503
504        for (layer_idx, layer) in self.layers.iter().enumerate() {
505            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
506            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
507            uvb_l
508                .pp("post_attention_layernorm")
509                .add(&layer.post_attention_layernorm);
510        }
511
512        uvb.to_safetensors()
513    }
514}