mistralrs_core/vision_models/qwen2_5_vl/
text.rs

1use std::{collections::HashMap, sync::Arc};
2
3use candle_core::{DType, Device, Result, Tensor};
4use candle_nn::{Embedding, Module};
5use mistralrs_quant::{
6    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
7};
8
9use super::config::Config;
10use crate::{
11    attention::SdpaParams,
12    device_map::DeviceMapper,
13    layers::{self, Activation, F32RmsNorm, Qwen2_5VLRotaryEmbedding, Sdpa},
14    paged_attention::{AttentionImplementation, ModelConfigMetadata},
15    pipeline::{
16        extract_logits, text_models_inputs_processor::FlashParams, EitherCache, IsqModel, KvCache,
17        NormalCache, NormalLoadingMetadata,
18    },
19    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
20};
21
22struct Mlp {
23    gate_proj: Arc<dyn QuantMethod>,
24    up_proj: Arc<dyn QuantMethod>,
25    down_proj: Arc<dyn QuantMethod>,
26    act_fn: Activation,
27}
28
29impl Mlp {
30    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
31        let hidden_sz = cfg.hidden_size;
32        let intermediate_sz = cfg.intermediate_size;
33        let gate_proj = ColumnParallelLayer::new(
34            hidden_sz,
35            intermediate_sz,
36            &cfg.quantization_config,
37            false,
38            comm,
39            vb.pp("gate_proj"),
40        )?;
41        let up_proj = ColumnParallelLayer::new(
42            hidden_sz,
43            intermediate_sz,
44            &cfg.quantization_config,
45            false,
46            comm,
47            vb.pp("up_proj"),
48        )?;
49        let down_proj = RowParallelLayer::new(
50            intermediate_sz,
51            hidden_sz,
52            &cfg.quantization_config,
53            false,
54            comm,
55            vb.pp("down_proj"),
56        )?;
57        Ok(Self {
58            gate_proj,
59            up_proj,
60            down_proj,
61            act_fn: cfg.hidden_act,
62        })
63    }
64
65    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
66        let original_dtype = xs.dtype();
67        let mut xs = xs.clone();
68        if let Some(t) = self.gate_proj.quantized_act_type() {
69            xs = xs.to_dtype(t)?;
70        }
71        let lhs = self.gate_proj.forward(&xs)?.apply(&self.act_fn)?;
72        let rhs = self.up_proj.forward(&xs)?;
73        self.down_proj
74            .forward(&(lhs * rhs)?)?
75            .to_dtype(original_dtype)
76    }
77}
78
79struct Attention {
80    q_proj: Arc<dyn QuantMethod>,
81    k_proj: Arc<dyn QuantMethod>,
82    v_proj: Arc<dyn QuantMethod>,
83    o_proj: Arc<dyn QuantMethod>,
84    num_heads: usize,
85    num_kv_heads: usize,
86    head_dim: usize,
87    rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
88    sdpa_params: SdpaParams,
89}
90
91impl Attention {
92    fn new(
93        rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
94        cfg: &Config,
95        vb: ShardedVarBuilder,
96        comm: &Arc<mistralrs_quant::Comm>,
97    ) -> Result<Self> {
98        let hidden_sz = cfg.hidden_size;
99        let num_heads = cfg.num_attention_heads;
100        let num_kv_heads = cfg.num_key_value_heads;
101        let head_dim = hidden_sz / num_heads;
102        let q_proj = ColumnParallelLayer::new(
103            hidden_sz,
104            num_heads * head_dim,
105            &cfg.quantization_config,
106            true,
107            comm,
108            vb.pp("q_proj"),
109        )?;
110        let kv_shard = mistralrs_quant::compute_kv_shard(
111            cfg.num_key_value_heads,
112            cfg.hidden_size / cfg.num_attention_heads,
113            comm,
114        );
115        let k_proj = ColumnParallelLayer::new_with_shard(
116            hidden_sz,
117            num_kv_heads * head_dim,
118            &cfg.quantization_config,
119            true,
120            comm,
121            kv_shard,
122            vb.pp("k_proj"),
123        )?;
124        let v_proj = ColumnParallelLayer::new_with_shard(
125            hidden_sz,
126            num_kv_heads * head_dim,
127            &cfg.quantization_config,
128            true,
129            comm,
130            kv_shard,
131            vb.pp("v_proj"),
132        )?;
133        let o_proj = RowParallelLayer::new(
134            num_heads * head_dim,
135            hidden_sz,
136            &cfg.quantization_config,
137            false,
138            comm,
139            vb.pp("o_proj"),
140        )?;
141        Ok(Self {
142            q_proj,
143            k_proj,
144            v_proj,
145            o_proj,
146            num_heads: num_heads / comm.world_size(),
147            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
148            head_dim,
149            rotary_emb,
150            sdpa_params: SdpaParams {
151                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
152                    cfg.num_key_value_heads,
153                    cfg.num_attention_heads,
154                    comm,
155                ),
156                use_flash_attn: false,
157                softcap: None,
158                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
159                sliding_window: None,
160            },
161        })
162    }
163
164    #[allow(clippy::too_many_arguments)]
165    fn forward(
166        &self,
167        xs: &Tensor,
168        attention_mask: Option<&Tensor>,
169        cos_sin: &(Tensor, Tensor),
170        kv_cache: &mut KvCache,
171        flash_params: &FlashParams,
172    ) -> Result<Tensor> {
173        let (b_sz, q_len, _) = xs.dims3()?;
174
175        let original_dtype = xs.dtype();
176        let mut xs = xs.clone();
177        if let Some(t) = self.q_proj.quantized_act_type() {
178            xs = xs.to_dtype(t)?;
179        }
180        let mut q = self.q_proj.forward(&xs)?;
181        let mut k = self.k_proj.forward(&xs)?;
182        let mut v = self.v_proj.forward(&xs)?;
183        if self.q_proj.quantized_act_type().is_some() {
184            q = q.to_dtype(original_dtype)?;
185            k = k.to_dtype(original_dtype)?;
186            v = v.to_dtype(original_dtype)?;
187        }
188
189        let (mut q, mut k, v) = if q_len != 1 {
190            let q = q
191                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
192                .transpose(1, 2)?;
193            let k = k
194                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
195                .transpose(1, 2)?;
196            let v = v
197                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
198                .transpose(1, 2)?;
199            (q, k, v)
200        } else {
201            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
202            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
203            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
204            (q, k, v)
205        };
206
207        // we share cos_sin amount all layers, but when some layers are in different device from layers[0],
208        // need to move cos_sin to corresponding device
209        let cos_sin_relocated = &(
210            cos_sin.0.to_device(q.device())?,
211            cos_sin.1.to_device(q.device())?,
212        );
213        self.rotary_emb.forward(cos_sin_relocated, &mut q, &mut k)?;
214
215        let mut attn_output = {
216            let (k, v) = kv_cache.append(&k, &v)?;
217
218            Sdpa.run_attention(
219                &q.contiguous()?.to_dtype(DType::F32)?,
220                &k.contiguous()?.to_dtype(DType::F32)?,
221                &v.contiguous()?.to_dtype(DType::F32)?,
222                attention_mask
223                    .map(|mask| mask.to_dtype(DType::F32).unwrap())
224                    .as_ref(),
225                Some(flash_params),
226                &self.sdpa_params,
227            )?
228            .to_dtype(q.dtype())?
229        };
230
231        if let Some(t) = self.q_proj.quantized_act_type() {
232            attn_output = attn_output.to_dtype(t)?;
233        }
234        attn_output = if attention_mask.is_some() {
235            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
236        } else {
237            attn_output.reshape((b_sz, q_len, ()))?
238        };
239        let mut res = self.o_proj.forward(&attn_output)?;
240        if self.q_proj.quantized_act_type().is_some() {
241            res = res.to_dtype(original_dtype)?;
242        }
243        Ok(res)
244    }
245}
246
247pub struct DecoderLayer {
248    self_attn: Attention,
249    mlp: Mlp,
250    input_layernorm: F32RmsNorm,
251    post_attention_layernorm: F32RmsNorm,
252}
253
254impl DecoderLayer {
255    fn new(
256        rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
257        cfg: &Config,
258        vb: ShardedVarBuilder,
259        mapper: &dyn DeviceMapper,
260        layer_idx: usize,
261        loading_isq: bool,
262        comm: &Arc<mistralrs_quant::Comm>,
263    ) -> Result<Self> {
264        let self_attn = Attention::new(
265            rotary_emb,
266            cfg,
267            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
268            comm,
269        )?;
270        let mlp = Mlp::new(
271            cfg,
272            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
273            comm,
274        )?;
275        let input_layernorm = F32RmsNorm::new(
276            cfg.hidden_size,
277            cfg.rms_norm_eps,
278            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
279        )?;
280        let post_attention_layernorm = F32RmsNorm::new(
281            cfg.hidden_size,
282            cfg.rms_norm_eps,
283            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
284        )?;
285        Ok(Self {
286            self_attn,
287            mlp,
288            input_layernorm,
289            post_attention_layernorm,
290        })
291    }
292
293    #[allow(clippy::too_many_arguments)]
294    fn forward(
295        &self,
296        xs: &Tensor,
297        attention_mask: Option<&Tensor>,
298        cos_sin: &(Tensor, Tensor),
299        kv_cache: &mut KvCache,
300        flash_params: &FlashParams,
301    ) -> Result<Tensor> {
302        let residual = xs;
303        let xs = self.input_layernorm.forward(xs)?;
304        let xs = self
305            .self_attn
306            .forward(&xs, attention_mask, cos_sin, kv_cache, flash_params)?;
307        let xs = (xs + residual)?;
308        let residual = &xs;
309        let xs = self
310            .mlp
311            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
312        residual + xs
313    }
314}
315
316pub struct Qwen2_5VLTextModel {
317    embed_tokens: Embedding,
318    pub(super) norm: F32RmsNorm,
319    layers: Vec<DecoderLayer>,
320    mapper: Box<dyn DeviceMapper + Send + Sync>,
321    lm_head: Arc<dyn QuantMethod>,
322    pub(super) cache: EitherCache,
323    pub(super) cfg: ModelConfigMetadata,
324    pub(super) device: Device,
325    pub(super) dtype: DType,
326    pub(super) max_seq_len: usize,
327}
328
329impl Qwen2_5VLTextModel {
330    pub fn new(
331        cfg: &Config,
332        vb: ShardedVarBuilder,
333        _is_gptx: bool,
334        normal_loading_metadata: NormalLoadingMetadata,
335        attention_mechanism: AttentionImplementation,
336    ) -> Result<Self> {
337        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
338            candle_core::bail!("Expected eager attention implementation");
339        }
340        let mapper = normal_loading_metadata.mapper;
341        let vb_m = vb.pp("model");
342
343        let embed_tokens = layers::embedding(
344            cfg.vocab_size,
345            cfg.hidden_size,
346            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
347        )?;
348        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
349        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
350
351        let mut ropes = HashMap::new();
352        for layer_idx in 0..cfg.num_hidden_layers {
353            let device = mapper
354                .device_for(layer_idx, false)
355                .unwrap_or(&normal_loading_metadata.real_device);
356            ropes.insert(
357                device.location(),
358                Arc::new(Qwen2_5VLRotaryEmbedding::new(
359                    cfg.rope_theta as f32,
360                    head_dim,
361                    device,
362                    cfg.rope_scaling.mrope_section.clone(),
363                )?),
364            );
365        }
366
367        let vb_l = vb_m.pp("layers");
368        for layer_idx in NiceProgressBar::<_, 'b'>(
369            0..cfg.num_hidden_layers,
370            "Loading repeating layers",
371            &normal_loading_metadata.multi_progress,
372        ) {
373            let device = mapper
374                .device_for(layer_idx, false)
375                .unwrap_or(&normal_loading_metadata.real_device);
376            let rotary_emb = ropes
377                .get(&device.location())
378                .expect("No RoPE for device location!")
379                .clone();
380            let comm = mapper.get_comm_for(layer_idx)?;
381            let layer = DecoderLayer::new(
382                rotary_emb.clone(),
383                cfg,
384                vb_l.pp(layer_idx),
385                &*mapper,
386                layer_idx,
387                normal_loading_metadata.loading_isq,
388                &comm,
389            )?;
390            layers.push(layer)
391        }
392        let norm = F32RmsNorm::new(
393            cfg.hidden_size,
394            cfg.rms_norm_eps,
395            mapper.set_nm_device(vb_m.pp("norm"), false),
396        )?;
397        let lm_head = if !cfg.tie_word_embeddings {
398            ReplicatedLayer::new(
399                cfg.hidden_size,
400                cfg.vocab_size,
401                &None,
402                false,
403                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
404            )?
405        } else {
406            ReplicatedLayer::from_linear(candle_nn::Linear::new(
407                mapper.cast_nm_device(
408                    embed_tokens.embeddings(),
409                    normal_loading_metadata.loading_isq,
410                )?,
411                None,
412            ))?
413        };
414        Ok(Self {
415            embed_tokens,
416            norm,
417            layers,
418            lm_head,
419            cache: EitherCache::Normal(NormalCache::new(
420                cfg.num_hidden_layers,
421                cfg.max_position_embeddings,
422            )),
423            max_seq_len: cfg.max_position_embeddings,
424            cfg: ModelConfigMetadata {
425                max_seq_len: cfg.max_position_embeddings,
426                num_layers: cfg.num_hidden_layers,
427                hidden_size: cfg.hidden_size,
428                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
429                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
430                    .max(1),
431                sliding_window: cfg.sliding_window,
432                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
433                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
434            },
435            device: normal_loading_metadata.real_device.clone(),
436            dtype: vb.dtype(),
437            mapper,
438        })
439    }
440
441    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
442        self.embed_tokens.forward(input_ids)
443    }
444
445    pub fn forward_embeds(
446        &self,
447        mut xs: Tensor,
448        attention_mask: Option<&Tensor>,
449        position_ids: &Tensor,
450        context_lens: Vec<(usize, usize)>,
451        flash_params: &FlashParams,
452    ) -> Result<Tensor> {
453        let cache = &mut self.cache.normal().0;
454        let cos_sin = self.layers[0]
455            .self_attn
456            .rotary_emb
457            .compute_cos_sin(position_ids, xs.dtype())?;
458
459        for (i, layer) in self.layers.iter().enumerate() {
460            xs = self.mapper.map(xs, i)?;
461            xs = layer.forward(
462                &xs,
463                attention_mask
464                    .as_ref()
465                    .map(|m| m.to_device(xs.device()).unwrap())
466                    .as_ref(),
467                &cos_sin,
468                &mut cache[i],
469                flash_params,
470            )?
471        }
472        let xs = xs.to_device(&self.device)?;
473        let mut xs = xs.apply(&self.norm)?;
474        if let Some(t) = self.lm_head.quantized_act_type() {
475            xs = xs.to_dtype(t)?;
476        }
477        extract_logits(&self.lm_head.forward(&xs)?, context_lens)
478    }
479}
480
481impl IsqModel for Qwen2_5VLTextModel {
482    fn get_layers(
483        &mut self,
484    ) -> (
485        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
486        &dyn DeviceMapper,
487    ) {
488        let mut tensors = Vec::new();
489        tensors.push((&mut self.lm_head, None));
490        for (i, layer) in self.layers.iter_mut().enumerate() {
491            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
492            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
493            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
494            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
495            tensors.push((&mut layer.mlp.gate_proj, Some(i)));
496            tensors.push((&mut layer.mlp.up_proj, Some(i)));
497            tensors.push((&mut layer.mlp.down_proj, Some(i)));
498        }
499        (tensors, &*self.mapper)
500    }
501
502    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
503        let uvb = UnVarBuilder::new();
504
505        let uvb_m = uvb.pp("model");
506        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
507        uvb_m.pp("norm").add(&self.norm);
508
509        for (layer_idx, layer) in self.layers.iter().enumerate() {
510            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
511            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
512            uvb_l
513                .pp("post_attention_layernorm")
514                .add(&layer.post_attention_layernorm);
515        }
516
517        uvb.to_safetensors()
518    }
519}