#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{extract_logits, EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;

const MAX_SEQ_LEN: u32 = 4096;

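/// Gated feed-forward block: computes `w2(silu(w1(x)) * w3(x))` on quantized weights.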
struct Mlp {
    feed_forward_w1: Arc<dyn QuantMethod>,
    feed_forward_w2: Arc<dyn QuantMethod>,
    feed_forward_w3: Arc<dyn QuantMethod>,
}

impl Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
        let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
        MatMul.qmethod_matmul(y, &*self.feed_forward_w2)
    }
}

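/// Weights and attention configuration for a single transformer block.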
struct LayerWeights {
    attention_wq: Arc<dyn QuantMethod>,
    attention_wk: Arc<dyn QuantMethod>,
    attention_wv: Arc<dyn QuantMethod>,
    attention_wo: Arc<dyn QuantMethod>,
    attention_norm: QRmsNorm,
    mlp: Mlp,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
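    /// Attention for one block: quantized Q/K/V projections, RoPE, then either
    /// paged attention or SDPA against this layer's KV cache, and finally the
    /// output projection back in the input dtype.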
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attention_wq)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attention_wk)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attention_wv)?
            .to_dtype(self.dtype)?;

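        // Prompt chunks (seq_len > 1) are reshaped then transposed into
        // (b, n_heads, seq, head_dim); for single-token decode a direct reshape
        // yields the same layout without a transpose.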
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

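        // With paged attention, the shared key/value caches arrive via `metadata`;
        // otherwise append K/V to this layer's cache and run standard SDPA.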
        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };

        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
        Ok(y)
    }
}

pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: Arc<dyn QuantMethod>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

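/// Hyperparameters read from GGUF metadata (keys under the `qwen2.` prefix).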
pub(crate) struct PropsGGUF {
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub rms_norm_eps: f32,
    pub max_seq_len: usize,
    pub rope_freq_base: f32,
    pub key_length: usize,
    pub value_length: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("qwen2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "attention.layer_norm_rms_epsilon",
        ];
        c.has_required_keys(&required)?;

        let embed_len = c.get_value::<u32>("embedding_length")? as usize;
        let head_count = c.get_value::<u32>("attention.head_count")? as usize;

        let props = Self {
            head_count,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: embed_len,
            rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(MAX_SEQ_LEN as u64) as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
            key_length: c
                .get_value::<u32>("attention.key_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
            value_length: c
                .get_value::<u32>("attention.value_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
        };

        Ok(props)
    }
}

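// Builds the model by reading quantized tensors and metadata from a GGUF file,
// distributing blocks across devices via the supplied `DeviceMapper`.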
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        let metadata = ContentMetadata {
            path_prefix: "qwen2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

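        // Create one RotaryEmbedding per device so blocks mapped to the same device
        // share a single RoPE instance.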
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    head_dim,
                    max_seq_len,
                    device,
                    true,
                    dtype,
                )?),
            );
        }

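        // Load each transformer block's quantized weights on its mapped device.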
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_bias_q = ct
                .tensor(&format!("{prefix}.attn_q.bias"), device)?
                .dequantize(device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_bias_k = ct
                .tensor(&format!("{prefix}.attn_k.bias"), device)?
                .dequantize(device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_bias_v = ct
                .tensor(&format!("{prefix}.attn_v.bias"), device)?
                .dequantize(device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;

            let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
            let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let mlp = Mlp {
                feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(feed_forward_w1),
                    b: None,
                })?),
                feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(feed_forward_w2),
                    b: None,
                })?),
                feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(feed_forward_w3),
                    b: None,
                })?),
            };

            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: Some(attention_bias_q),
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: Some(attention_bias_k),
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: Some(attention_bias_v),
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
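    /// Full forward pass: embed tokens, build the causal mask, run every decoder
    /// block, apply the final norm and output head, and extract logits for the
    /// requested context positions.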
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            x,
            metadata
                .as_ref()
                .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
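        // Per block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual.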
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                mask.as_ref()
                    .map(|m| m.to_device(x.device()).unwrap())
                    .as_ref(),
                start_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let x = (attn + residual)?;

            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let x = self.norm.forward(&layer_in)?;
        extract_logits(
            &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?,
            context_lens,
        )
    }
}