mistralrs_core/vision_models/qwen2vl/text.rs

1use std::{collections::HashMap, sync::Arc};
2
3use candle_core::{DType, Device, Result, Tensor};
4use candle_nn::{Embedding, Module};
5use mistralrs_quant::{
6    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
7};
8
9use crate::{
10    attention::SdpaParams,
11    device_map::DeviceMapper,
12    layers::{self, Activation, F32RmsNorm, Qwen2VLRotaryEmbedding, Sdpa},
13    paged_attention::{AttentionImplementation, ModelConfigMetadata},
14    pipeline::{
15        extract_logits, text_models_inputs_processor::FlashParams, EitherCache, IsqModel, KvCache,
16        NormalCache, NormalLoadingMetadata,
17    },
18    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
19};
20
21use super::config::Config;
22
23struct Mlp {
24    gate_proj: Arc<dyn QuantMethod>,
25    up_proj: Arc<dyn QuantMethod>,
26    down_proj: Arc<dyn QuantMethod>,
27    act_fn: Activation,
28}
29
30impl Mlp {
31    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
32        let hidden_sz = cfg.hidden_size;
33        let intermediate_sz = cfg.intermediate_size;
34        let gate_proj = ColumnParallelLayer::new(
35            hidden_sz,
36            intermediate_sz,
37            &cfg.quantization_config,
38            false,
39            comm,
40            vb.pp("gate_proj"),
41        )?;
42        let up_proj = ColumnParallelLayer::new(
43            hidden_sz,
44            intermediate_sz,
45            &cfg.quantization_config,
46            false,
47            comm,
48            vb.pp("up_proj"),
49        )?;
50        let down_proj = RowParallelLayer::new(
51            intermediate_sz,
52            hidden_sz,
53            &cfg.quantization_config,
54            false,
55            comm,
56            vb.pp("down_proj"),
57        )?;
58        Ok(Self {
59            gate_proj,
60            up_proj,
61            down_proj,
62            act_fn: cfg.hidden_act,
63        })
64    }
65
66    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
67        let original_dtype = xs.dtype();
68        let mut xs = xs.clone();
69        if let Some(t) = self.gate_proj.quantized_act_type() {
70            xs = xs.to_dtype(t)?;
71        }
72        let lhs = self.gate_proj.forward(&xs)?.apply(&self.act_fn)?;
73        let rhs = self.up_proj.forward(&xs)?;
74        self.down_proj
75            .forward(&(lhs * rhs)?)?
76            .to_dtype(original_dtype)
77    }
78}
79
80struct Attention {
81    q_proj: Arc<dyn QuantMethod>,
82    k_proj: Arc<dyn QuantMethod>,
83    v_proj: Arc<dyn QuantMethod>,
84    o_proj: Arc<dyn QuantMethod>,
85    num_heads: usize,
86    num_kv_heads: usize,
87    head_dim: usize,
88    rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
89    sdpa_params: SdpaParams,
90}
91
92impl Attention {
93    fn new(
94        rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
95        cfg: &Config,
96        vb: ShardedVarBuilder,
97        comm: &Arc<mistralrs_quant::Comm>,
98    ) -> Result<Self> {
99        let hidden_sz = cfg.hidden_size;
100        let num_heads = cfg.num_attention_heads;
101        let num_kv_heads = cfg.num_key_value_heads;
102        let head_dim = hidden_sz / num_heads;
103        let q_proj = ColumnParallelLayer::new(
104            hidden_sz,
105            num_heads * head_dim,
106            &cfg.quantization_config,
107            true,
108            comm,
109            vb.pp("q_proj"),
110        )?;
111        let kv_shard = mistralrs_quant::compute_kv_shard(
112            cfg.num_key_value_heads,
113            cfg.hidden_size / cfg.num_attention_heads,
114            comm,
115        );
116        let k_proj = ColumnParallelLayer::new_with_shard(
117            hidden_sz,
118            num_kv_heads * head_dim,
119            &cfg.quantization_config,
120            true,
121            comm,
122            kv_shard,
123            vb.pp("k_proj"),
124        )?;
125        let v_proj = ColumnParallelLayer::new_with_shard(
126            hidden_sz,
127            num_kv_heads * head_dim,
128            &cfg.quantization_config,
129            true,
130            comm,
131            kv_shard,
132            vb.pp("v_proj"),
133        )?;
134        let o_proj = RowParallelLayer::new(
135            num_heads * head_dim,
136            hidden_sz,
137            &cfg.quantization_config,
138            false,
139            comm,
140            vb.pp("o_proj"),
141        )?;
142        Ok(Self {
143            q_proj,
144            k_proj,
145            v_proj,
146            o_proj,
147            num_heads: num_heads / comm.world_size(),
148            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
149            head_dim,
150            rotary_emb,
151            sdpa_params: SdpaParams {
152                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
153                    cfg.num_key_value_heads,
154                    cfg.num_attention_heads,
155                    comm,
156                ),
157                softcap: None,
158                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
159                sliding_window: None,
160            },
161        })
162    }
163
164    #[allow(clippy::too_many_arguments)]
165    fn forward(
166        &self,
167        xs: &Tensor,
168        attention_mask: Option<&Tensor>,
169        cos_sin: &(Tensor, Tensor),
170        kv_cache: &mut KvCache,
171        flash_params: &FlashParams,
172    ) -> Result<Tensor> {
173        let (b_sz, q_len, _) = xs.dims3()?;
174
175        let original_dtype = xs.dtype();
176        let mut xs = xs.clone();
177        if let Some(t) = self.q_proj.quantized_act_type() {
178            xs = xs.to_dtype(t)?;
179        }
180        let mut q = self.q_proj.forward(&xs)?;
181        let mut k = self.k_proj.forward(&xs)?;
182        let mut v = self.v_proj.forward(&xs)?;
183        if self.q_proj.quantized_act_type().is_some() {
184            q = q.to_dtype(original_dtype)?;
185            k = k.to_dtype(original_dtype)?;
186            v = v.to_dtype(original_dtype)?;
187        }
188
189        let (mut q, mut k, v) = if q_len != 1 {
190            let q = q
191                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
192                .transpose(1, 2)?;
193            let k = k
194                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
195                .transpose(1, 2)?;
196            let v = v
197                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
198                .transpose(1, 2)?;
199            (q, k, v)
200        } else {
201            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
202            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
203            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
204            (q, k, v)
205        };
206
207        self.rotary_emb.forward(cos_sin, &mut q, &mut k)?;
208
209        let mut attn_output = {
210            let (k, v) = kv_cache.append(&k, &v)?;
211
212            Sdpa.run_attention(
213                &q.contiguous()?.to_dtype(DType::F32)?,
214                &k.contiguous()?.to_dtype(DType::F32)?,
215                &v.contiguous()?.to_dtype(DType::F32)?,
216                attention_mask
217                    .map(|mask| mask.to_dtype(DType::F32).unwrap())
218                    .as_ref(),
219                Some(flash_params),
220                &self.sdpa_params,
221            )?
222            .to_dtype(q.dtype())?
223        };
224
225        if let Some(t) = self.q_proj.quantized_act_type() {
226            attn_output = attn_output.to_dtype(t)?;
227        }
228        attn_output = if attention_mask.is_some() {
229            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
230        } else {
231            attn_output.reshape((b_sz, q_len, ()))?
232        };
233        let mut res = self.o_proj.forward(&attn_output)?;
234        if self.q_proj.quantized_act_type().is_some() {
235            res = res.to_dtype(original_dtype)?;
236        }
237        Ok(res)
238    }
239}
240
241pub struct DecoderLayer {
242    self_attn: Attention,
243    mlp: Mlp,
244    input_layernorm: F32RmsNorm,
245    post_attention_layernorm: F32RmsNorm,
246}
247
248impl DecoderLayer {
249    fn new(
250        rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
251        cfg: &Config,
252        vb: ShardedVarBuilder,
253        mapper: &dyn DeviceMapper,
254        layer_idx: usize,
255        loading_isq: bool,
256        comm: &Arc<mistralrs_quant::Comm>,
257    ) -> Result<Self> {
258        let self_attn = Attention::new(
259            rotary_emb,
260            cfg,
261            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
262            comm,
263        )?;
264        let mlp = Mlp::new(
265            cfg,
266            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
267            comm,
268        )?;
269        let input_layernorm = F32RmsNorm::new(
270            cfg.hidden_size,
271            cfg.rms_norm_eps,
272            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
273        )?;
274        let post_attention_layernorm = F32RmsNorm::new(
275            cfg.hidden_size,
276            cfg.rms_norm_eps,
277            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
278        )?;
279        Ok(Self {
280            self_attn,
281            mlp,
282            input_layernorm,
283            post_attention_layernorm,
284        })
285    }
286
287    #[allow(clippy::too_many_arguments)]
288    fn forward(
289        &self,
290        xs: &Tensor,
291        attention_mask: Option<&Tensor>,
292        cos_sin: &(Tensor, Tensor),
293        kv_cache: &mut KvCache,
294        flash_params: &FlashParams,
295    ) -> Result<Tensor> {
296        let residual = xs;
297        let xs = self.input_layernorm.forward(xs)?;
298        let xs = self
299            .self_attn
300            .forward(&xs, attention_mask, cos_sin, kv_cache, flash_params)?;
301        let xs = (xs + residual)?;
302        let residual = &xs;
303        let xs = self
304            .mlp
305            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
306        residual + xs
307    }
308}
309
310pub struct Qwen2VLTextModel {
311    embed_tokens: Embedding,
312    pub(super) norm: F32RmsNorm,
313    layers: Vec<DecoderLayer>,
314    mapper: Box<dyn DeviceMapper + Send + Sync>,
315    lm_head: Arc<dyn QuantMethod>,
316    pub(super) cache: EitherCache,
317    pub(super) cfg: ModelConfigMetadata,
318    pub(super) device: Device,
319    pub(super) dtype: DType,
320    pub(super) max_seq_len: usize,
321}
322
323impl Qwen2VLTextModel {
324    pub fn new(
325        cfg: &Config,
326        vb: ShardedVarBuilder,
327        _is_gptx: bool,
328        normal_loading_metadata: NormalLoadingMetadata,
329        attention_mechanism: AttentionImplementation,
330    ) -> Result<Self> {
331        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
332            candle_core::bail!("Expected eager attention implementation");
333        }
334        let mapper = normal_loading_metadata.mapper;
335        let vb_m = vb.pp("model");
336
337        let embed_tokens = layers::embedding(
338            cfg.vocab_size,
339            cfg.hidden_size,
340            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
341            &cfg.quantization_config,
342        )?;
343        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
344
345        let mut ropes = HashMap::new();
346        for layer_idx in 0..cfg.num_hidden_layers {
347            let device = mapper
348                .device_for(layer_idx, false)
349                .unwrap_or(&normal_loading_metadata.real_device);
350            ropes.insert(
351                device.location(),
352                Arc::new(Qwen2VLRotaryEmbedding::new(
353                    cfg.rope_theta as f32,
354                    head_dim,
355                    device,
356                    cfg.rope_scaling.mrope_section.clone(),
357                )?),
358            );
359        }
360
361        let vb_l = vb_m.pp("layers");
362        let layers = NiceProgressBar::<_, 'b'>(
363            0..cfg.num_hidden_layers,
364            "Loading repeating layers",
365            &normal_loading_metadata.multi_progress,
366        )
367        .par_iter_if_isq(|layer_idx| {
368            let device = mapper
369                .device_for(layer_idx, false)
370                .unwrap_or(&normal_loading_metadata.real_device);
371            let rotary_emb = ropes
372                .get(&device.location())
373                .expect("No RoPE for device location!")
374                .clone();
375            let comm = mapper.get_comm_for(layer_idx)?;
376            DecoderLayer::new(
377                rotary_emb.clone(),
378                cfg,
379                vb_l.pp(layer_idx),
380                &*mapper,
381                layer_idx,
382                normal_loading_metadata.loading_isq,
383                &comm,
384            )
385        })?;
386        let norm = F32RmsNorm::new(
387            cfg.hidden_size,
388            cfg.rms_norm_eps,
389            mapper.set_nm_device(vb_m.pp("norm"), false),
390        )?;
391        let lm_head = if !cfg.tie_word_embeddings {
392            ReplicatedLayer::new(
393                cfg.hidden_size,
394                cfg.vocab_size,
395                &cfg.quantization_config,
396                false,
397                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
398            )?
399        } else {
400            ReplicatedLayer::from_linear(candle_nn::Linear::new(
401                mapper.cast_nm_device(
402                    embed_tokens.embeddings(),
403                    normal_loading_metadata.loading_isq,
404                )?,
405                None,
406            ))?
407        };
408        Ok(Self {
409            embed_tokens,
410            norm,
411            layers,
412            lm_head,
413            cache: EitherCache::Normal(NormalCache::new(
414                cfg.num_hidden_layers,
415                cfg.max_position_embeddings,
416            )),
417            max_seq_len: cfg.max_position_embeddings,
418            cfg: ModelConfigMetadata {
419                max_seq_len: cfg.max_position_embeddings,
420                num_layers: cfg.num_hidden_layers,
421                hidden_size: cfg.hidden_size,
422                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
423                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
424                    .max(1),
425                sliding_window: cfg.sliding_window,
426                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
427                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
428            },
429            device: normal_loading_metadata.real_device.clone(),
430            dtype: vb.dtype(),
431            mapper,
432        })
433    }
434
435    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
436        self.embed_tokens.forward(input_ids)
437    }
438
439    pub fn forward_embeds(
440        &self,
441        mut xs: Tensor,
442        attention_mask: Option<&Tensor>,
443        position_ids: &Tensor,
444        context_lens: Vec<(usize, usize)>,
445        flash_params: &FlashParams,
446    ) -> Result<Tensor> {
447        let cache = &mut self.cache.normal().0;
448        let cos_sin = self.layers[0]
449            .self_attn
450            .rotary_emb
451            .compute_cos_sin(position_ids, xs.dtype())?;
452
453        for (i, layer) in self.layers.iter().enumerate() {
454            xs = self.mapper.map(xs, i)?;
455            xs = layer.forward(
456                &xs,
457                attention_mask
458                    .as_ref()
459                    .map(|m| m.to_device(xs.device()).unwrap())
460                    .as_ref(),
461                &cos_sin,
462                &mut cache[i],
463                flash_params,
464            )?
465        }
466        let xs = xs.to_device(&self.device)?;
467        let mut xs = xs.apply(&self.norm)?;
468        if let Some(t) = self.lm_head.quantized_act_type() {
469            xs = xs.to_dtype(t)?;
470        }
471        extract_logits(&self.lm_head.forward(&xs)?, context_lens)
472    }
473}
474
475impl IsqModel for Qwen2VLTextModel {
476    fn get_layers(
477        &mut self,
478    ) -> (
479        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
480        &dyn DeviceMapper,
481    ) {
482        let mut tensors = Vec::new();
483        tensors.push((&mut self.lm_head, None));
484        for (i, layer) in self.layers.iter_mut().enumerate() {
485            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
486            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
487            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
488            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
489            tensors.push((&mut layer.mlp.gate_proj, Some(i)));
490            tensors.push((&mut layer.mlp.up_proj, Some(i)));
491            tensors.push((&mut layer.mlp.down_proj, Some(i)));
492        }
493        (tensors, &*self.mapper)
494    }
495
496    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
497        let uvb = UnVarBuilder::new();
498
499        let uvb_m = uvb.pp("model");
500        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
501        uvb_m.pp("norm").add(&self.norm);
502
503        for (layer_idx, layer) in self.layers.iter().enumerate() {
504            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
505            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
506            uvb_l
507                .pp("post_attention_layernorm")
508                .add(&layer.post_attention_layernorm);
509        }
510
511        uvb.to_safetensors()
512    }
513}