mistralrs_core/models/quantized_starcoder2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QLinear, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::{Embedding, LayerNorm};
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

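/// Feed-forward block: up projection, GELU (PyTorch tanh approximation), then down projection.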
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        MatMul.qmethod_matmul(
            &MatMul
                .qmethod_matmul(xs, &*self.ffn_up)?
                .apply(&candle_nn::Activation::GeluPytorchTanh)?,
            &*self.ffn_down,
        )
    }
}

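/// Dequantizes the GGUF weight and bias tensors and wraps them in a `LayerNorm`.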
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    Ok(LayerNorm::new(w, b, eps))
}

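/// Weights and attention configuration for a single transformer block.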
struct LayerWeights {
    attn_q: Arc<dyn QuantMethod>,
    attn_k: Arc<dyn QuantMethod>,
    attn_v: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: LayerNorm,
    ffn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
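    /// Self-attention for one block: QKV projections, RoPE, then either paged attention or
    /// SDPA over the appended KV cache, followed by the output projection.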
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, q_len, hidden_size) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attn_q)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attn_k)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attn_v)?
            .to_dtype(self.dtype)?;

        let (q, k, v) = if q_len != 1 {
            // Prefill: reshape to (b, seq, heads, head_dim) and transpose to (b, heads, seq, head_dim).
            let q = q
                .reshape((b_sz, q_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            // Decode (single token): reshape directly, no transpose needed.
            let q = q.reshape((b_sz, self.n_head, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape(&[b_sz, q_len, hidden_size])?
        } else {
            y.reshape(&[b_sz, q_len, hidden_size])?
        };

        MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)
    }
}

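/// StarCoder2 model loaded from GGUF: token embeddings, transformer blocks, final layer
/// norm, and the quantized output head.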
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QMatMul,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    dtype: DType,
}

// starcoder2 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
pub(crate) struct PropsGGUF {
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub layer_norm_epsilon: f64,
    pub context_window: usize,
    pub rope_freq_base: f32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("starcoder2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "attention.layer_norm_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            layer_norm_epsilon: c.get_value::<f32>("attention.layer_norm_epsilon")? as f64,
            context_window: c.get_value::<u32>("context_length")? as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(100_000_f32),
        };

        Ok(props)
    }
}

impl ModelConfig::FromGGUF for ModelWeights {
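    /// Loads the model from a GGUF container: parses metadata, dequantizes the embedding and
    /// norm tensors, builds one rotary embedding per mapped device, and wraps each quantized
    /// projection in a `GgufMatMul`.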
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "starcoder2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            layer_norm_epsilon,
            context_window,
            rope_freq_base,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let head_dim = embedding_length / head_count;
        let output_norm = layer_norm(
            ct.tensor("output_norm.weight", device)?,
            ct.tensor("output_norm.bias", device)?,
            layer_norm_epsilon,
        )?;
        let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
        let mut layers = Vec::with_capacity(block_count);

        // Build one rotary embedding per target device so layers mapped to the same
        // device can share it.
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    head_dim,
                    context_window,
                    device,
                    true,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: ffn_up.bias().cloned(),
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: ffn_down.bias().cloned(),
                })?),
            };
            let attn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
                layer_norm_epsilon,
            )?;
            let ffn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.ffn_norm.bias"), device)?,
                layer_norm_epsilon,
            )?;
            let attn_q = QLinear::new(&mut ct, &format!("{prefix}.attn_q"), device)?;
            let attn_k = QLinear::new(&mut ct, &format!("{prefix}.attn_k"), device)?;
            let attn_v = QLinear::new(&mut ct, &format!("{prefix}.attn_v"), device)?;
            let attn_output = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let QMatMul::QTensor(q_w) = attn_q.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(k_w) = attn_k.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(v_w) = attn_v.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(o_w) = attn_output.inner_ref().clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_q: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: q_w,
                    b: attn_q.bias().cloned(),
                })?),
                attn_k: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: k_w,
                    b: attn_k.bias().cloned(),
                })?),
                attn_v: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: v_w,
                    b: attn_v.bias().cloned(),
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: o_w,
                    b: attn_output.bias().cloned(),
                })?),
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary_emb: rotary,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, context_window)),
            max_seq_len: context_window,
            dtype,
        })
    }
}

impl ModelWeights {
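    /// Forward pass: embeds `input_ids`, runs each transformer block (with optional device
    /// mapping and paged-attention metadata), and returns logits for the last position of
    /// each sequence.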
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (_b_sz, seq_len) = input_ids.dims2()?;
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // Keep the mask only for the first prompt chunk (or when no paged-attention metadata is given).
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(&ys)?;
            xs = (ys + residual)?
        }
        // Only the final position is needed for next-token logits.
        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
        MatMul.qmatmul(&xs, &self.output)
    }
}