#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, RmsNorm, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

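/// Feed-forward block. The GGUF `ffn_up` tensor packs the gate and up projections
/// together: the first `i_size` output features are the gate and the next `i_size`
/// are the up projection.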
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
    i_size: usize,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The fused projection produces [gate | up] along the last dimension.
        let up_states = MatMul.qmethod_matmul(xs, &*self.ffn_up)?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        // SiLU-gated activation, then project back down to the hidden size.
        let up_states = (up_states * gate.silu()?)?;
        MatMul.qmethod_matmul(&up_states, &*self.ffn_down)
    }
}

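/// Builds an [`RmsNorm`] from a GGUF-quantized weight by dequantizing it on the
/// tensor's own device.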
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

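/// Weights for one transformer block: fused QKV and output projections, the
/// pre-attention and pre-FFN RMSNorms, the MLP, RoPE cos/sin tables, and an
/// optional paged-attention backend.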
struct LayerWeights {
    attn_qkv: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        // Each sequence in the batch may start at a different position, so slice the
        // cos/sin tables at that sequence's offset and apply RoPE per batch entry.
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;
        // Single fused projection producing [Q | K | V] along the last dimension.
        let qkv = MatMul
            .qmethod_matmul(x, &*self.attn_qkv)?
            .to_dtype(self.dtype)?;
        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        // Reshape to (batch, heads, seq, head_dim); the single-token decode path can
        // reshape directly without a transpose.
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v.contiguous()?)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;
        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                // Eager path: extend the per-layer KV cache, then run scaled
                // dot-product attention over the full cached keys/values.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)?;
        Ok(y)
    }
}

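/// A Phi-3 model loaded from GGUF: token embeddings, the stack of transformer
/// blocks, the final RMSNorm, and the quantized output head.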
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QMatMul,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    dtype: DType,
}

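/// Precomputes the RoPE cos/sin tables: inverse frequencies 1 / freq_base^(i/head_dim)
/// for even `i`, evaluated at every position up to `context_window`.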
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

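/// Hyperparameters read from the `phi3.*` keys of the GGUF metadata.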
pub(crate) struct PropsGGUF {
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub i_size: usize,
    pub rope_dim: usize,
    pub rms_eps: f64,
    pub context_window: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("phi3")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "feed_forward_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            i_size: c.get_value::<u32>("feed_forward_length")? as usize,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            rms_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
            context_window: c.get_value::<u32>("context_length")? as usize,
        };

        Ok(props)
    }
}

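// Loads the model from a GGUF container: reads the `phi3.*` metadata, precomputes
// the RoPE tables, and loads the embedding, per-block, and output tensors onto the
// devices chosen by the device mapper.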
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        // Token embeddings are dequantized; the output head stays quantized.
        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
        let mut layers = Vec::with_capacity(block_count);
        let head_dim = embedding_length / head_count;

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            // Each block's tensors are loaded onto the device chosen by the mapper.
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?)?;
            let ffn_down =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: None,
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: None,
                })?),
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let qkv =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?)?;
            let out =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_output.weight"), device)?)?;
            let QMatMul::QTensor(qkv_w) = qkv.clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(out_w) = out.clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: qkv_w,
                    b: None,
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: out_w,
                    b: None,
                })?),
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new_sliding(
                block_count,
                context_window,
                Some(context_window),
            )),
            max_seq_len: context_window,
            dtype,
        })
    }
}

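// Forward pass: embed tokens, run each block with a sliding-window causal mask
// (eager SDPA or paged attention), then project the last position to logits.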
impl ModelWeights {
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (_b_sz, seq_len) = input_ids.dims2()?;
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // Drop the causal mask for non-first prompt chunks when paged-attention
        // metadata is present.
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            // Pre-norm attention block with a residual connection.
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let ys = (ys + residual)?;
            // Pre-norm MLP block with a residual connection.
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(&ys)?;
            xs = (ys + residual)?
        }
        // Normalize, then keep only the last position before projecting to logits.
        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
        MatMul.qmatmul(&xs, &self.output)
    }
}