mistralrs_core/models/quantized_qwen2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{extract_logits, EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;

const MAX_SEQ_LEN: u32 = 4096;

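/// SwiGLU-style feed-forward block over GGUF-quantized weights:
/// `w2(silu(w1(x)) * w3(x))`, with `w1` = gate, `w3` = up, and `w2` = down projection.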
struct Mlp {
    feed_forward_w1: Arc<dyn QuantMethod>,
    feed_forward_w2: Arc<dyn QuantMethod>,
    feed_forward_w3: Arc<dyn QuantMethod>,
}

impl Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
        let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
        MatMul.qmethod_matmul(y, &*self.feed_forward_w2)
    }
}

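/// Weights and per-layer configuration for one Qwen2 transformer block:
/// Q/K/V/O projections (Q/K/V carry GGUF bias tensors), RMS norms, the MLP,
/// rotary embeddings, and an optional paged-attention backend.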
struct LayerWeights {
    attention_wq: Arc<dyn QuantMethod>,
    attention_wk: Arc<dyn QuantMethod>,
    attention_wv: Arc<dyn QuantMethod>,
    attention_wo: Arc<dyn QuantMethod>,
    attention_norm: QRmsNorm,
    mlp: Mlp,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attention_wq)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attention_wk)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attention_wv)?
            .to_dtype(self.dtype)?;

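        // Prefill (seq_len > 1): reshape to (b, seq, heads, head_dim) and transpose to
        // (b, heads, seq, head_dim). Decode (seq_len == 1): the head/seq axes can be
        // produced directly by a reshape, so the transpose is skipped.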
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

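        // Paged attention manages the KV cache through `input_metadata`; the eager path
        // appends K/V to the per-layer cache and runs a standard SDPA computation.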
        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

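        // With a mask (prefill), the output is (b, heads, seq, head_dim) and must be
        // transposed before flattening; in the single-token decode path the sequence
        // axis has length 1, so a plain reshape suffices.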
        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };

        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
        Ok(y)
    }
}

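/// Complete GGUF-quantized Qwen2 model: token embeddings, the stack of transformer
/// blocks, the final RMS norm, and the output (LM head) projection, which falls back
/// to the tied token-embedding weights when `output.weight` is absent from the GGUF.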
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: Arc<dyn QuantMethod>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

// qwen2 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub rms_norm_eps: f32,
    pub max_seq_len: usize,
    pub rope_freq_base: f32,
    pub key_length: usize,
    pub value_length: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("qwen2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "attention.layer_norm_rms_epsilon",
        ];
        c.has_required_keys(&required)?;

        let embed_len = c.get_value::<u32>("embedding_length")? as usize;
        let head_count = c.get_value::<u32>("attention.head_count")? as usize;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            head_count,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: embed_len,
            // Strangely, this value is generally 1e-6 in the GGUF file, but the default used to be 1e-5.
            rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(MAX_SEQ_LEN as u64) as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
            key_length: c
                .get_value::<u32>("attention.key_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
            value_length: c
                .get_value::<u32>("attention.value_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
        };

        Ok(props)
    }
}

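// Assembles the model from a GGUF `Content`: reads the qwen2.* metadata, loads all
// tensors, and places each block on the device chosen by the `DeviceMapper`.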
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "qwen2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

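        // Token embeddings are dequantized for the embedding lookup; the LM head reuses
        // the (tied) `token_embd.weight` tensor when no separate `output.weight` exists.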
        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

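        // Build one rotary-embedding table per device location so that layers mapped to
        // the same device share a single instance.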
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    head_dim,
                    max_seq_len,
                    device,
                    true,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_bias_q = ct
                .tensor(&format!("{prefix}.attn_q.bias"), device)?
                .dequantize(device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_bias_k = ct
                .tensor(&format!("{prefix}.attn_k.bias"), device)?
                .dequantize(device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_bias_v = ct
                .tensor(&format!("{prefix}.attn_v.bias"), device)?
                .dequantize(device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;

            let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
            let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let mlp = Mlp {
                feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(feed_forward_w1),
                    b: None,
                })?),
                feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(feed_forward_w2),
                    b: None,
                })?),
                feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(feed_forward_w3),
                    b: None,
                })?),
            };

            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: Some(attention_bias_q),
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: Some(attention_bias_k),
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: Some(attention_bias_v),
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
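    /// Runs the full forward pass: embeds `x`, applies every transformer block with the
    /// per-layer KV cache (or paged-attention `metadata` when provided), applies the final
    /// norm, and returns logits for the positions selected by `context_lens`.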
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            x,
            metadata
                .as_ref()
                .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
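        // With paged-attention metadata, the mask is only applied on the first prompt
        // chunk; without metadata (eager path), any mask produced above is kept.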
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
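        // Each block runs pre-norm attention and MLP sub-layers with residual adds; the
        // device mapper moves activations to the block's device before it executes.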
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                mask.as_ref()
                    .map(|m| m.to_device(x.device()).unwrap())
                    .as_ref(),
                start_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let x = (attn + residual)?;

            // MLP
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let x = self.norm.forward(&layer_in)?;
        extract_logits(
            &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?,
            context_lens,
        )
    }
}
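
// A rough usage sketch (illustrative only, not part of the model implementation):
// assuming a GGUF `Content` (`ct`), a target `device`, a `DeviceMapper` (`mapper`),
// and tokenized input ids as a `Tensor` (`input_ids`), loading the model and running
// a single forward pass might look roughly like this. All of these bindings are
// hypothetical placeholders, and `vec![(0, 1)]` assumes logits are only wanted for
// one position.
//
//     use crate::utils::model_config::FromGGUF;
//
//     let model = ModelWeights::from_gguf(
//         ct,
//         &device,
//         mapper,
//         AttentionImplementation::Eager,
//         DType::F32,
//     )?;
//     let logits = model.forward(&input_ids, &[0], vec![(0, 1)], None)?;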