#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QLinear, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

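// Feed-forward block: quantized up-projection, GELU (PyTorch tanh approximation),
// then quantized down-projection.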
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        MatMul.qmethod_matmul(
            &MatMul
                .qmethod_matmul(xs, &*self.ffn_up)?
                .apply(&candle_nn::Activation::GeluPytorchTanh)?,
            &*self.ffn_down,
        )
    }
}

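// Builds a LayerNorm from quantized GGUF weight and bias tensors by dequantizing them on their device.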
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    let ln = LayerNorm::new(w, b, eps);
    Ok(ln)
}

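// Weights for a single transformer block: attention projections and norms, the MLP,
// and the attention configuration (head counts, RoPE, optional paged attention, SDPA parameters).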
struct LayerWeights {
    attn_q: Arc<dyn QuantMethod>,
    attn_k: Arc<dyn QuantMethod>,
    attn_v: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: LayerNorm,
    ffn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
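    /// Self-attention for one block: quantized Q/K/V projections, RoPE, then either
    /// paged attention or SDPA over the per-layer KV cache, followed by the output projection.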
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, q_len, hidden_size) = x.dims3()?;

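        // Project the hidden states to Q/K/V with the quantized projections,
        // casting to the attention compute dtype.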
        let q = MatMul
            .qmethod_matmul(x, &*self.attn_q)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attn_k)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attn_v)?
            .to_dtype(self.dtype)?;

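        // Shape Q/K/V as (batch, heads, seq, head_dim). The single-token decode path
        // can reshape directly without a transpose.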
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?;
            (q, k, v)
        };

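        // Apply rotary position embeddings to the queries and keys.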
        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

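        // With paged attention, the caller-provided KV cache tensors and metadata are used;
        // otherwise the new keys/values are appended to this layer's cache before running SDPA.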
        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

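        // Merge the attention heads back into the hidden dimension; the masked (prompt)
        // path still has the heads axis before the sequence axis and is transposed first.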
        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape(&[b_sz, q_len, hidden_size])?
        } else {
            y.reshape(&[b_sz, q_len, hidden_size])?
        };

        MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)
    }
}

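// The full quantized StarCoder2 model: token embeddings, the repeating blocks,
// the final norm and output head, plus the device mapper and KV cache.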
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QMatMul,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    dtype: DType,
}

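// Hyperparameters parsed from the GGUF metadata for the `starcoder2` architecture.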
pub(crate) struct PropsGGUF {
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub layer_norm_epsilon: f64,
    pub context_window: usize,
    pub rope_freq_base: f32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("starcoder2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "attention.layer_norm_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            layer_norm_epsilon: c.get_value::<f32>("attention.layer_norm_epsilon")? as f64,
            context_window: c.get_value::<u32>("context_length")? as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(100_000_f32),
        };

        Ok(props)
    }
}

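// Loads the model from GGUF content: read the metadata, then the token embedding,
// output norm and head, and each repeating transformer block.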
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        let metadata = ContentMetadata {
            path_prefix: "starcoder2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            layer_norm_epsilon,
            context_window,
            rope_freq_base,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

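        // The token embedding table is dequantized up front since `Embedding` needs a dense tensor.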
        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let head_dim = embedding_length / head_count;
        let output_norm = layer_norm(
            ct.tensor("output_norm.weight", device)?,
            ct.tensor("output_norm.bias", device)?,
            layer_norm_epsilon,
        )?;
        let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
        let mut layers = Vec::with_capacity(block_count);

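        // Create one RotaryEmbedding per device so that layers mapped to the same device share it.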
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    head_dim,
                    context_window,
                    device,
                    true,
                    dtype,
                )?),
            );
        }

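        // Load the repeating blocks behind a progress bar, rebuilding each quantized
        // projection as a GgufMatMul together with its bias.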
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

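            // MLP projections come back as QLinear; unwrap the raw quantized tensors
            // so they can be re-wrapped as GgufMatMul with their biases.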
            let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: ffn_up.bias().cloned(),
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: ffn_down.bias().cloned(),
                })?),
            };
            let attn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
                layer_norm_epsilon,
            )?;
            let ffn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.ffn_norm.bias"), device)?,
                layer_norm_epsilon,
            )?;
            let attn_q = QLinear::new(&mut ct, &format!("{prefix}.attn_q"), device)?;
            let attn_k = QLinear::new(&mut ct, &format!("{prefix}.attn_k"), device)?;
            let attn_v = QLinear::new(&mut ct, &format!("{prefix}.attn_v"), device)?;
            let attn_output = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
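            // A PagedAttention kernel is only constructed when the paged backend is selected;
            // the eager path relies on the per-layer KV cache instead.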
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
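            // Unwrap the attention projections the same way as the MLP ones above.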
            let QMatMul::QTensor(q_w) = attn_q.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(k_w) = attn_k.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(v_w) = attn_v.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(o_w) = attn_output.inner_ref().clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_q: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: q_w,
                    b: attn_q.bias().cloned(),
                })?),
                attn_k: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: k_w,
                    b: attn_k.bias().cloned(),
                })?),
                attn_v: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: v_w,
                    b: attn_v.bias().cloned(),
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: o_w,
                    b: attn_output.bias().cloned(),
                })?),
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary_emb: rotary,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
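        // Assemble the model; the per-layer KV cache is sized for the full context window.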
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, context_window)),
            max_seq_len: context_window,
            dtype,
        })
    }
}

impl ModelWeights {
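    /// Runs the forward pass over `input_ids`, returning the logits for the last position
    /// of each sequence.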
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (_b_sz, seq_len) = input_ids.dims2()?;
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
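        // Build the causal mask. When paged attention metadata is present, the mask is
        // only kept for the first chunk of a prompt.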
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
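        // Pre-norm transformer blocks: LayerNorm + attention, then LayerNorm + MLP,
        // each added back to its residual.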
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(&ys)?;
            xs = (ys + residual)?
        }
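        // Final norm, then keep only the last position and project it to vocabulary logits.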
        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
        MatMul.qmatmul(&xs, &self.output)
    }
}