mistralrs_core/models/quantized_starcoder2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QLinear, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

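/// Feed-forward block: an up projection, GELU (PyTorch tanh approximation), and a down
/// projection, each backed by a quantized matmul.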
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        MatMul.qmethod_matmul(
            &MatMul
                .qmethod_matmul(xs, &*self.ffn_up)?
                .apply(&candle_nn::Activation::GeluPytorchTanh)?,
            &*self.ffn_down,
        )
    }
}

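/// Dequantizes the GGUF weight and bias tensors and wraps them in a candle `LayerNorm`.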
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    let ln = LayerNorm::new(w, b, eps);
    Ok(ln)
}

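/// Weights and configuration for one transformer block: attention projections, layer
/// norms, the MLP, rotary embeddings, and an optional paged-attention backend.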
struct LayerWeights {
    attn_q: Arc<dyn QuantMethod>,
    attn_k: Arc<dyn QuantMethod>,
    attn_v: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: LayerNorm,
    ffn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
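    /// Self-attention for one layer: project Q/K/V, apply rotary embeddings, then either
    /// run PagedAttention or append to the KV cache and run SDPA, and finally apply the
    /// output projection.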
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, q_len, hidden_size) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attn_q)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attn_k)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attn_v)?
            .to_dtype(self.dtype)?;

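        // Prompt processing (q_len > 1) reshapes to (b, seq, heads, head_dim) and transposes;
        // single-token decode reshapes directly to (b, heads, 1, head_dim), skipping the
        // transpose since it would be a no-op.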
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

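        // With paged attention, the provided key/value cache blocks are used directly;
        // otherwise the new keys/values are appended to this layer's KV cache before SDPA.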
        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape(&[b_sz, q_len, hidden_size])?
        } else {
            y.reshape(&[b_sz, q_len, hidden_size])?
        };

        MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)
    }
}

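/// Full StarCoder2 model loaded from GGUF: token embeddings, the stack of transformer
/// blocks, the final layer norm, and the quantized output head.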
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QMatMul,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    dtype: DType,
}

// starcoder2 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
pub(crate) struct PropsGGUF {
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub layer_norm_epsilon: f64,
    pub context_window: usize,
    pub rope_freq_base: f32,
}

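// Reads the required starcoder2 metadata keys from the GGUF header; `rope.freq_base` is
// optional and falls back to 100_000.0 when absent.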
impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("starcoder2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "attention.layer_norm_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            layer_norm_epsilon: c.get_value::<f32>("attention.layer_norm_epsilon")? as f64,
            context_window: c.get_value::<u32>("context_length")? as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(100_000_f32),
        };

        Ok(props)
    }
}

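// Builds the model from GGUF content: parses the metadata, dequantizes the token
// embeddings and layer norms, keeps the projections and output head quantized, and
// constructs each transformer block on its mapped device.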
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "starcoder2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            layer_norm_epsilon,
            context_window,
            rope_freq_base,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let head_dim = embedding_length / head_count;
        let output_norm = layer_norm(
            ct.tensor("output_norm.weight", device)?,
            ct.tensor("output_norm.bias", device)?,
            layer_norm_epsilon,
        )?;
        let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
        let mut layers = Vec::with_capacity(block_count);

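        // Build one rotary embedding table per device location so that layers mapped to
        // the same device can share it.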
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    head_dim,
                    context_window,
                    device,
                    true,
                    dtype,
                )?),
            );
        }

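        // Load each transformer block: tensors are read as `QLinear`, then their quantized
        // weights and biases are rewrapped as `GgufMatMul` quant methods.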
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: ffn_up.bias().cloned(),
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: ffn_down.bias().cloned(),
                })?),
            };
            let attn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
                layer_norm_epsilon,
            )?;
            let ffn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.ffn_norm.bias"), device)?,
                layer_norm_epsilon,
            )?;
            let attn_q = QLinear::new(&mut ct, &format!("{prefix}.attn_q"), device)?;
            let attn_k = QLinear::new(&mut ct, &format!("{prefix}.attn_k"), device)?;
            let attn_v = QLinear::new(&mut ct, &format!("{prefix}.attn_v"), device)?;
            let attn_output = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let QMatMul::QTensor(q_w) = attn_q.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(k_w) = attn_k.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(v_w) = attn_v.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(o_w) = attn_output.inner_ref().clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_q: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: q_w,
                    b: attn_q.bias().cloned(),
                })?),
                attn_k: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: k_w,
                    b: attn_k.bias().cloned(),
                })?),
                attn_v: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: v_w,
                    b: attn_v.bias().cloned(),
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: o_w,
                    b: attn_output.bias().cloned(),
                })?),
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary_emb: rotary,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, context_window)),
            max_seq_len: context_window,
            dtype,
        })
    }
}

impl ModelWeights {
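    /// Runs the model over `input_ids`: embeds the tokens, applies each transformer block
    /// with residual connections, then the final layer norm and output head on the last
    /// position.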
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (_b_sz, seq_len) = input_ids.dims2()?;
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
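        // Build the causal mask from the current KV length; with paged attention it is
        // only applied on the first prompt chunk.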
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(&ys)?;
            xs = (ys + residual)?
        }
        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
        MatMul.qmatmul(&xs, &self.output)
    }
}