1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::attention::SdpaParams;
7use crate::device_map::DeviceMapper;
8use crate::gguf::Content;
9use crate::layers::{CausalMasker, MatMul, QLinear, RotaryEmbedding, Sdpa};
10use crate::layers_masker::PastKvLenCache;
11use crate::paged_attention::{AttentionImplementation, PagedAttention};
12use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
13use crate::pipeline::{EitherCache, KvCache, NormalCache};
14use crate::utils::gguf_metadata::ContentMetadata;
15use crate::utils::model_config as ModelConfig;
16use crate::utils::progress::NiceProgressBar;
17use candle_core::quantized::QMatMul;
18use candle_core::quantized::QTensor;
19use candle_core::{DType, Device, IndexOp, Module, Result, Tensor};
20use candle_nn::{Embedding, LayerNorm};
21use indicatif::MultiProgress;
22use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};
23
/// Feed-forward block of a decoder layer.
///
/// Holds the two quantized projections that the `Module::forward`
/// implementation applies in sequence: `ffn_up`, then `ffn_down`.
#[derive(Clone)]
struct Mlp {
    /// First projection; its output is passed through the activation.
    ffn_up: Arc<dyn QuantMethod>,
    /// Second projection, applied to the activated `ffn_up` output.
    ffn_down: Arc<dyn QuantMethod>,
}
29
30impl Module for Mlp {
31 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
32 MatMul.qmethod_matmul(
33 &MatMul
34 .qmethod_matmul(xs, &*self.ffn_up)?
35 .apply(&candle_nn::Activation::GeluPytorchTanh)?,
36 &*self.ffn_down,
37 )
38 }
39}
40
41fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
42 let w = w.dequantize(&w.device())?;
43 let b = b.dequantize(&b.device())?;
44 let ln = LayerNorm::new(w, b, eps);
45 Ok(ln)
46}
47
/// Weights and attention configuration for a single decoder layer.
struct LayerWeights {
    /// Query projection applied to the layer input in `forward_attn`.
    attn_q: Arc<dyn QuantMethod>,
    /// Key projection applied to the layer input in `forward_attn`.
    attn_k: Arc<dyn QuantMethod>,
    /// Value projection applied to the layer input in `forward_attn`.
    attn_v: Arc<dyn QuantMethod>,
    /// Output projection applied to the attention result.
    attn_output: Arc<dyn QuantMethod>,
    /// Normalization applied before the attention sub-block.
    attn_norm: LayerNorm,
    /// Normalization applied before the feed-forward sub-block.
    ffn_norm: LayerNorm,
    /// Feed-forward sub-block.
    mlp: Mlp,
    /// Number of query heads used when reshaping `q`.
    n_head: usize,
    /// Number of key/value heads used when reshaping `k` and `v`.
    n_kv_head: usize,
    /// Size of each attention head used when reshaping `q`, `k` and `v`.
    head_dim: usize,
    /// Rotary embedding applied to `q` and `k`.
    rotary_emb: Arc<RotaryEmbedding>,
    /// When `Some`, attention runs through paged attention; when `None`, the
    /// per-layer `KvCache` and `Sdpa` are used instead.
    paged_attn: Option<PagedAttention>,
    /// Parameters handed to whichever attention implementation is used.
    sdpa_params: SdpaParams,
    /// Dtype that `q`, `k` and `v` are cast to after their projections.
    dtype: DType,
}
64
65impl LayerWeights {
66 fn forward_attn(
67 &self,
68 x: &Tensor,
69 mask: Option<&Tensor>,
70 seqlen_offsets: &[usize],
71 kv_cache: &mut KvCache,
72 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
73 ) -> Result<Tensor> {
74 let (b_sz, q_len, hidden_size) = x.dims3()?;
75
76 let q = MatMul
77 .qmethod_matmul(x, &*self.attn_q)?
78 .to_dtype(self.dtype)?;
79 let k = MatMul
80 .qmethod_matmul(x, &*self.attn_k)?
81 .to_dtype(self.dtype)?;
82 let v = MatMul
83 .qmethod_matmul(x, &*self.attn_v)?
84 .to_dtype(self.dtype)?;
85
86 let (q, k, v) = if q_len != 1 {
87 let q = q
88 .reshape((b_sz, q_len, self.n_head, self.head_dim))?
89 .transpose(1, 2)?;
90 let k = k
91 .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
92 .transpose(1, 2)?;
93 let v = v
94 .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
95 .transpose(1, 2)?;
96 (q, k, v)
97 } else {
98 let q = q.reshape((b_sz, self.n_head, q_len, self.head_dim))?;
99 let k = k.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
100 let v = v.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
101 (q, k, v)
102 };
103
104 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
105
106 let y = match &self.paged_attn {
107 Some(paged_attn) => {
108 let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
109 paged_attn.forward(
110 &q,
111 &k,
112 &v,
113 mask,
114 Some(key_cache),
115 Some(value_cache),
116 input_metadata,
117 &self.sdpa_params,
118 None,
119 )?
120 }
121 None => {
122 let (k, v) = kv_cache.append(&k, &v)?;
123
124 Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
125 }
126 };
127
128 let y = if mask.is_some() {
129 y.transpose(1, 2)?.reshape(&[b_sz, q_len, hidden_size])?
130 } else {
131 y.reshape(&[b_sz, q_len, hidden_size])?
132 };
133
134 MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)
135 }
136}
137
/// Quantized StarCoder2 model loaded from a GGUF file.
pub struct ModelWeights {
    /// Token embedding table applied to the input ids.
    tok_embeddings: Embedding,
    /// Decoder layers, applied in order by `forward`.
    layers: Vec<LayerWeights>,
    /// Normalization applied after the last decoder layer.
    output_norm: LayerNorm,
    /// Final quantized projection applied to the normalized hidden state.
    output: QMatMul,
    /// Device mapper used to move the hidden state before each layer, if any.
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    /// Device the model was loaded with.
    pub device: Device,
    /// Key/value cache; constructed as a `NormalCache` by `from_gguf`.
    pub cache: EitherCache,
    /// Maximum sequence length, taken from the GGUF context length.
    pub max_seq_len: usize,
    /// Dtype used when building the causal mask.
    dtype: DType,
}
149
/// Model hyperparameters read from the `starcoder2` section of GGUF metadata.
pub(crate) struct PropsGGUF {
    /// Value of `attention.head_count`.
    pub head_count: usize,
    /// Value of `attention.head_count_kv`.
    pub head_count_kv: usize,
    /// Value of `block_count`.
    pub block_count: usize,
    /// Value of `embedding_length`.
    pub embedding_length: usize,
    /// Value of `attention.layer_norm_epsilon`, widened from `f32`.
    pub layer_norm_epsilon: f64,
    /// Value of `context_length`.
    pub context_window: usize,
    /// Value of `rope.freq_base`, or `100_000` when it cannot be read.
    pub rope_freq_base: f32,
}
161
162impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
163 type Error = anyhow::Error;
164
165 fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
166 c.verify_arch("starcoder2")?;
167
168 let required = [
169 "attention.head_count",
170 "attention.head_count_kv",
171 "block_count",
172 "embedding_length",
173 "attention.layer_norm_epsilon",
174 "context_length",
175 ];
176 c.has_required_keys(&required)?;
177
178 let props = Self {
181 head_count: c.get_value::<u32>("attention.head_count")? as usize,
182 head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
183 block_count: c.get_value::<u32>("block_count")? as usize,
184 embedding_length: c.get_value::<u32>("embedding_length")? as usize,
185 layer_norm_epsilon: c.get_value::<f32>("attention.layer_norm_epsilon")? as f64,
186 context_window: c.get_value::<u32>("context_length")? as usize,
187 rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(100_000_f32),
188 };
189
190 Ok(props)
191 }
192}
193
194impl ModelConfig::FromGGUF for ModelWeights {
195 fn from_gguf<R: std::io::Seek + std::io::Read>(
196 mut ct: Content<'_, R>,
197 device: &Device,
198 mapper: Box<dyn DeviceMapper + Send + Sync>,
199 attention_mechanism: AttentionImplementation,
200 dtype: DType,
201 ) -> Result<Self> {
202 let metadata = ContentMetadata {
204 path_prefix: "starcoder2",
205 metadata: ct.get_metadata(),
206 };
207 let PropsGGUF {
208 head_count,
209 head_count_kv,
210 block_count,
211 embedding_length,
212 layer_norm_epsilon,
213 context_window,
214 rope_freq_base,
215 } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;
216
217 let tok_embeddings = ct.tensor("token_embd.weight", device)?;
218 let tok_embeddings = tok_embeddings.dequantize(device)?;
219 let head_dim = embedding_length / head_count;
220 let output_norm = layer_norm(
221 ct.tensor("output_norm.weight", device)?,
222 ct.tensor("output_norm.bias", device)?,
223 layer_norm_epsilon,
224 )?;
225 let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
226 let mut layers = Vec::with_capacity(block_count);
227
228 let mut ropes = HashMap::new();
229 for layer_idx in 0..block_count {
230 let device = mapper.device_for(layer_idx, false).unwrap_or(device);
231 ropes.insert(
232 device.location(),
233 Arc::new(RotaryEmbedding::new(
234 rope_freq_base,
235 head_dim,
236 context_window,
237 device,
238 true,
239 dtype,
240 )?),
241 );
242 }
243
244 for layer_idx in NiceProgressBar::<_, 'b'>(
245 0..block_count,
246 "Loading repeating layers",
247 &MultiProgress::new(),
248 ) {
249 let prefix = format!("blk.{layer_idx}");
250 let device = mapper.device_for(layer_idx, false).unwrap_or(device);
251 let rotary = ropes
252 .get(&device.location())
253 .expect("No RoPE for device location!")
254 .clone();
255
256 let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
257 let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
258 let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
259 unreachable!()
260 };
261 let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
262 unreachable!()
263 };
264 let mlp = Mlp {
265 ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
266 q_weight: ffn_up_w,
267 b: ffn_up.bias().cloned(),
268 })?),
269 ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
270 q_weight: ffn_down_w,
271 b: ffn_down.bias().cloned(),
272 })?),
273 };
274 let attn_norm = layer_norm(
275 ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
276 ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
277 layer_norm_epsilon,
278 )?;
279 let ffn_norm = layer_norm(
280 ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
281 ct.tensor(&format!("{prefix}.ffn_norm.bias"), device)?,
282 layer_norm_epsilon,
283 )?;
284 let attn_q = QLinear::new(&mut ct, &format!("{prefix}.attn_q"), device)?;
285 let attn_k = QLinear::new(&mut ct, &format!("{prefix}.attn_k"), device)?;
286 let attn_v = QLinear::new(&mut ct, &format!("{prefix}.attn_v"), device)?;
287 let attn_output = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
288 let paged_attn = match &attention_mechanism {
289 AttentionImplementation::Eager => None,
290 AttentionImplementation::PagedAttention => {
291 Some(PagedAttention::new(head_dim, device, None)?)
292 }
293 };
294 let QMatMul::QTensor(q_w) = attn_q.inner_ref().clone() else {
295 unreachable!()
296 };
297 let QMatMul::QTensor(k_w) = attn_k.inner_ref().clone() else {
298 unreachable!()
299 };
300 let QMatMul::QTensor(v_w) = attn_v.inner_ref().clone() else {
301 unreachable!()
302 };
303 let QMatMul::QTensor(o_w) = attn_output.inner_ref().clone() else {
304 unreachable!()
305 };
306 layers.push(LayerWeights {
307 attn_q: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
308 q_weight: q_w,
309 b: attn_q.bias().cloned(),
310 })?),
311 attn_k: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
312 q_weight: k_w,
313 b: attn_k.bias().cloned(),
314 })?),
315 attn_v: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
316 q_weight: v_w,
317 b: attn_v.bias().cloned(),
318 })?),
319 attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
320 q_weight: o_w,
321 b: attn_output.bias().cloned(),
322 })?),
323 attn_norm,
324 ffn_norm,
325 mlp,
326 n_head: head_count,
327 n_kv_head: head_count_kv,
328 head_dim,
329 rotary_emb: rotary,
330 paged_attn,
331 sdpa_params: SdpaParams {
332 n_kv_groups: head_count / head_count_kv,
333 use_flash_attn: false,
334 softcap: None,
335 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
336 sliding_window: None,
337 },
338 dtype,
339 })
340 }
341 Ok(Self {
342 tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
343 layers,
344 output_norm,
345 output,
346 mapper: Some(mapper),
347 device: device.clone(),
348 cache: EitherCache::Normal(NormalCache::new(block_count, context_window)),
349 max_seq_len: context_window,
350 dtype,
351 })
352 }
353}
354
355impl ModelWeights {
356 pub fn forward(
357 &self,
358 input_ids: &Tensor,
359 seqlen_offsets: &[usize],
360 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
361 ) -> Result<Tensor> {
362 let (_b_sz, seq_len) = input_ids.dims2()?;
363 let mut xs = self.tok_embeddings.forward(input_ids)?;
364 let cache = &mut self.cache.normal().0;
365 let mask = CausalMasker.make_causal_mask_matrix(
366 input_ids,
367 metadata
368 .as_ref()
369 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
370 .unwrap_or(cache as &dyn PastKvLenCache),
371 self.dtype,
372 self.layers[0].n_head,
373 )?;
374 let mask = mask.filter(|_| {
375 metadata
376 .as_ref()
377 .map(|(_, meta)| meta.is_first_prompt_chunk)
378 .unwrap_or(true)
379 });
380 for (i, layer) in self.layers.iter().enumerate() {
381 if let Some(ref mapper) = self.mapper {
382 xs = mapper.map(xs, i)?;
383 }
384 let residual = &xs;
385 let ys = xs.apply(&layer.attn_norm)?;
386 let ys = layer.forward_attn(
387 &ys,
388 mask.as_ref()
389 .map(|m| m.to_device(xs.device()).unwrap())
390 .as_ref(),
391 seqlen_offsets,
392 &mut cache[i],
393 metadata
394 .as_ref()
395 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
396 )?;
397 let ys = (ys + residual)?;
398 let residual = &ys;
399 let ys = ys.apply(&layer.ffn_norm)?;
400 let ys = layer.mlp.forward(&ys)?;
401 xs = (ys + residual)?
402 }
403 let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
404 MatMul.qmatmul(&xs, &self.output)
405 }
406}