mistralrs_core/vision_models/qwen2_5_vl/text.rs

use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use super::config::Config;
use crate::{
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, F32RmsNorm, Qwen2_5VLRotaryEmbedding, Sdpa},
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        extract_logits, text_models_inputs_processor::FlashParams, EitherCache, IsqModel, KvCache,
        NormalCache, NormalLoadingMetadata,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

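/// Gated feed-forward block: column-parallel `gate_proj`/`up_proj`, the configured
/// activation on the gate branch, and a row-parallel `down_proj`.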
struct Mlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let up_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let down_proj = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }

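    /// Applies the gated MLP, casting activations to the quantized activation dtype when
    /// the projections are quantized and restoring the original dtype on the way out.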
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate_proj.forward(&xs)?.apply(&self.act_fn)?;
        let rhs = self.up_proj.forward(&xs)?;
        self.down_proj
            .forward(&(lhs * rhs)?)?
            .to_dtype(original_dtype)
    }
}

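/// Grouped-query self-attention with tensor-parallel Q/K/V/O projections and the
/// Qwen2.5-VL multimodal rotary embedding shared across layers on the same device.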
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

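    /// Projects Q/K/V, applies the rotary embedding (moving the shared `cos_sin` to this
    /// layer's device if needed), appends to the KV cache, and runs SDPA in f32 before
    /// the output projection.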
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        cos_sin: &(Tensor, Tensor),
        kv_cache: &mut KvCache,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&xs)?;
        let mut k = self.k_proj.forward(&xs)?;
        let mut v = self.v_proj.forward(&xs)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (mut q, mut k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // cos_sin is shared among all layers, but when a layer lives on a different device
        // than layers[0], it must first be moved to that layer's device.
        let cos_sin_relocated = &(
            cos_sin.0.to_device(q.device())?,
            cos_sin.1.to_device(q.device())?,
        );
        self.rotary_emb.forward(cos_sin_relocated, &mut q, &mut k)?;

        let mut attn_output = {
            let (k, v) = kv_cache.append(&k, &v)?;

            Sdpa.run_attention(
                &q.contiguous()?.to_dtype(DType::F32)?,
                &k.contiguous()?.to_dtype(DType::F32)?,
                &v.contiguous()?.to_dtype(DType::F32)?,
                attention_mask
                    .map(|mask| mask.to_dtype(DType::F32).unwrap())
                    .as_ref(),
                Some(flash_params),
                &self.sdpa_params,
            )?
            .to_dtype(q.dtype())?
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

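/// One decoder layer: pre-norm self-attention and a pre-norm MLP, each wrapped in a
/// residual connection.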
pub struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: F32RmsNorm,
    post_attention_layernorm: F32RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            comm,
        )?;
        let mlp = Mlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

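    /// Pre-norm residual forward pass: `x + attn(norm(x))`, then `x + mlp(norm(x))`.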
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        cos_sin: &(Tensor, Tensor),
        kv_cache: &mut KvCache,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(&xs, attention_mask, cos_sin, kv_cache, flash_params)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

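/// The language-model half of Qwen2.5-VL: token embeddings, the decoder stack, a final
/// RMS norm, and the (possibly tied) LM head.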
pub struct Qwen2_5VLTextModel {
    embed_tokens: Embedding,
    pub(super) norm: F32RmsNorm,
    layers: Vec<DecoderLayer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    lm_head: Arc<dyn QuantMethod>,
    pub(super) cache: EitherCache,
    pub(super) cfg: ModelConfigMetadata,
    pub(super) device: Device,
    pub(super) dtype: DType,
    pub(super) max_seq_len: usize,
}

impl Qwen2_5VLTextModel {
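    /// Builds the text model from the config and weights: per-device rotary embeddings,
    /// device-mapped decoder layers (with optional ISQ), the final norm, and an LM head
    /// that reuses the token embeddings when `tie_word_embeddings` is set.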
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected eager attention implementation");
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Qwen2_5VLRotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    device,
                    cfg.rope_scaling.mrope_section.clone(),
                )?),
            );
        }

        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            norm,
            layers,
            lm_head,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            device: normal_loading_metadata.real_device.clone(),
            dtype: vb.dtype(),
            mapper,
        })
    }

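    /// Embeds input token ids with the model's token embedding table.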
    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

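    /// Runs the decoder stack over precomputed input embeddings. The rotary `cos`/`sin`
    /// tensors are computed once from `position_ids` and shared by all layers; logits are
    /// extracted only for the positions in `context_lens`.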
    pub fn forward_embeds(
        &self,
        mut xs: Tensor,
        attention_mask: Option<&Tensor>,
        position_ids: &Tensor,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let cache = &mut self.cache.normal().0;
        let cos_sin = self.layers[0]
            .self_attn
            .rotary_emb
            .compute_cos_sin(position_ids, xs.dtype())?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                &cos_sin,
                &mut cache[i],
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&self.lm_head.forward(&xs)?, context_lens)
    }
}

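// ISQ support: `get_layers` exposes the quantizable linear projections (plus the LM head),
// and `residual_tensors` collects the non-quantized tensors for serialization.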
impl IsqModel for Qwen2_5VLTextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.mlp.gate_proj, Some(i)));
            tensors.push((&mut layer.mlp.up_proj, Some(i)));
            tensors.push((&mut layer.mlp.down_proj, Some(i)));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}