1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use crate::attention::SdpaParams;
6use crate::device_map::DeviceMapper;
7use crate::gguf::Content;
8use crate::layers::{CausalMasker, MatMul, RmsNorm, Sdpa};
9use crate::layers_masker::PastKvLenCache;
10use crate::paged_attention::{AttentionImplementation, PagedAttention};
11use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
12use crate::pipeline::{EitherCache, KvCache, NormalCache};
13use crate::utils::gguf_metadata::ContentMetadata;
14use crate::utils::model_config as ModelConfig;
15use crate::utils::progress::{new_multi_progress, NiceProgressBar};
16use candle_core::quantized::QMatMul;
17use candle_core::quantized::QTensor;
18use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
19use candle_nn::Embedding;
20use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};
21
22#[derive(Clone)]
23struct Mlp {
24 ffn_up: Arc<dyn QuantMethod>,
25 ffn_down: Arc<dyn QuantMethod>,
26 i_size: usize,
27}
28
29impl Module for Mlp {
30 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
31 let up_states = MatMul.qmethod_matmul(xs, &*self.ffn_up)?;
32 let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
33 let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
34 let up_states = (up_states * gate.silu()?)?;
35 MatMul.qmethod_matmul(&up_states, &*self.ffn_down)
36 }
37}
38
39fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
40 let w = w.dequantize(&w.device())?;
41 let rms = RmsNorm::from_w(w, eps)?;
42 Ok(rms)
43}
44
45struct LayerWeights {
46 attn_qkv: Arc<dyn QuantMethod>,
47 attn_output: Arc<dyn QuantMethod>,
48 attn_norm: RmsNorm,
49 ffn_norm: RmsNorm,
50 mlp: Mlp,
51 n_head: usize,
52 n_kv_head: usize,
53 head_dim: usize,
54 cos: Tensor,
55 sin: Tensor,
56 paged_attn: Option<PagedAttention>,
57 sdpa_params: SdpaParams,
58 dtype: DType,
59}
60
61impl LayerWeights {
62 fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
63 let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
64 let mut outputs = Vec::new();
65 for (i, offset) in seqlen_offsets.iter().enumerate() {
66 let cos = self.cos.narrow(0, *offset, seq_len)?;
67 let sin = self.sin.narrow(0, *offset, seq_len)?;
68 outputs.push(candle_nn::rotary_emb::rope(
69 &xs.i(i)?.unsqueeze(0)?.contiguous()?,
70 &cos,
71 &sin,
72 )?);
73 }
74 Tensor::cat(&outputs, 0)
75 }
76
77 fn forward_attn(
78 &self,
79 x: &Tensor,
80 mask: Option<&Tensor>,
81 seqlen_offsets: &[usize],
82 kv_cache: &mut KvCache,
83 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
84 ) -> Result<Tensor> {
85 let (b_sz, seq_len, _) = x.dims3()?;
86 let qkv = MatMul
87 .qmethod_matmul(x, &*self.attn_qkv)?
88 .to_dtype(self.dtype)?;
89 let query_pos = self.n_head * self.head_dim;
90 let q = qkv.narrow(D::Minus1, 0, query_pos)?;
91 let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
92 let v = qkv.narrow(
93 D::Minus1,
94 query_pos + self.n_kv_head * self.head_dim,
95 self.n_kv_head * self.head_dim,
96 )?;
97
98 let (q, k, v) = if seq_len != 1 {
99 let q = q
100 .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
101 .transpose(1, 2)?;
102 let k = k
103 .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
104 .transpose(1, 2)?;
105 let v = v
106 .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
107 .transpose(1, 2)?;
108 (q, k, v.contiguous()?)
109 } else {
110 let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
111 let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
112 let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
113 (q, k, v)
114 };
115
116 let q = self.apply_rotary_emb(&q, seqlen_offsets)?;
117 let k = self.apply_rotary_emb(&k, seqlen_offsets)?;
118 let y = match &self.paged_attn {
119 Some(paged_attn) => {
120 let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
121 paged_attn.forward(
122 &q,
123 &k,
124 &v,
125 mask,
126 Some(key_cache),
127 Some(value_cache),
128 input_metadata,
129 &self.sdpa_params,
130 None,
131 )?
132 }
133 None => {
134 let (k, v) = kv_cache.append(&k, &v)?;
135
136 Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
137 }
138 };
139
140 let y = if mask.is_some() {
141 y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
142 } else {
143 y.reshape((b_sz, seq_len, ()))?
144 };
145 let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)?;
146 Ok(y)
147 }
148}
149
150pub struct ModelWeights {
151 tok_embeddings: Embedding,
152 layers: Vec<LayerWeights>,
153 output_norm: RmsNorm,
154 output: QMatMul,
155 mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
156 pub device: Device,
157 pub cache: EitherCache,
158 pub max_seq_len: usize,
159 dtype: DType,
160}
161
162fn precomput_freqs_cis(
163 head_dim: usize,
164 freq_base: f32,
165 device: &Device,
166 context_window: usize,
167 dtype: DType,
168) -> Result<(Tensor, Tensor)> {
169 let theta: Vec<_> = (0..head_dim)
170 .step_by(2)
171 .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
172 .collect();
173 let theta = Tensor::new(theta.as_slice(), device)?;
174 let idx_theta = Tensor::arange(0, context_window as u32, device)?
175 .to_dtype(DType::F32)?
176 .reshape((context_window, 1))?
177 .matmul(&theta.reshape((1, theta.elem_count()))?)?;
178 let cos = idx_theta.cos()?.to_dtype(dtype)?;
179 let sin = idx_theta.sin()?.to_dtype(dtype)?;
180 Ok((cos, sin))
181}
182
/// Model hyper-parameters extracted from GGUF metadata.
pub(crate) struct PropsGGUF {
    /// Value of the `attention.head_count` key.
    pub head_count: usize,
    /// Value of the `attention.head_count_kv` key.
    pub head_count_kv: usize,
    /// Value of the `block_count` key.
    pub block_count: usize,
    /// Value of the `embedding_length` key.
    pub embedding_length: usize,
    /// Value of the `feed_forward_length` key.
    pub i_size: usize,
    /// Value of the `rope.dimension_count` key.
    pub rope_dim: usize,
    /// Value of the `attention.layer_norm_rms_epsilon` key, widened to `f64`.
    pub rms_eps: f64,
    /// Value of the `context_length` key.
    pub context_window: usize,
}
196
197impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
198 type Error = anyhow::Error;
199
200 fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
201 c.verify_arch("phi3")?;
202
203 let required = [
204 "attention.head_count",
205 "attention.head_count_kv",
206 "block_count",
207 "embedding_length",
208 "feed_forward_length",
209 "rope.dimension_count",
210 "attention.layer_norm_rms_epsilon",
211 "context_length",
212 ];
213 c.has_required_keys(&required)?;
214
215 let props = Self {
218 head_count: c.get_value::<u32>("attention.head_count")? as usize,
219 head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
220 block_count: c.get_value::<u32>("block_count")? as usize,
221 embedding_length: c.get_value::<u32>("embedding_length")? as usize,
222 i_size: c.get_value::<u32>("feed_forward_length")? as usize,
223 rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
224 rms_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
225 context_window: c.get_value::<u32>("context_length")? as usize,
226 };
227
228 Ok(props)
229 }
230}
231
232impl ModelConfig::FromGGUF for ModelWeights {
233 fn from_gguf<R: std::io::Seek + std::io::Read>(
234 mut ct: Content<'_, R>,
235 device: &Device,
236 mapper: Box<dyn DeviceMapper + Send + Sync>,
237 attention_mechanism: AttentionImplementation,
238 dtype: DType,
239 ) -> Result<Self> {
240 let metadata = ContentMetadata {
242 path_prefix: "phi3",
243 metadata: ct.get_metadata(),
244 };
245 let PropsGGUF {
246 head_count,
247 head_count_kv,
248 block_count,
249 embedding_length,
250 i_size,
251 rope_dim,
252 rms_eps,
253 context_window,
254 } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;
255
256 let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;
257
258 let tok_embeddings = ct.tensor("token_embd.weight", device)?;
259 let tok_embeddings = tok_embeddings.dequantize(device)?;
260 let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
261 let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
262 let mut layers = Vec::with_capacity(block_count);
263 let head_dim = embedding_length / head_count;
264
265 for layer_idx in NiceProgressBar::<_, 'b'>(
266 0..block_count,
267 "Loading repeating layers",
268 &new_multi_progress(),
269 ) {
270 let prefix = format!("blk.{layer_idx}");
271 let device = mapper.device_for(layer_idx, false).unwrap_or(device);
272 let ffn_up =
273 QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?)?;
274 let ffn_down =
275 QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?)?;
276 let QMatMul::QTensor(ffn_up_w) = ffn_up else {
277 unreachable!()
278 };
279 let QMatMul::QTensor(ffn_down_w) = ffn_down else {
280 unreachable!()
281 };
282 let mlp = Mlp {
283 ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
284 q_weight: ffn_up_w,
285 b: None,
286 })?),
287 ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
288 q_weight: ffn_down_w,
289 b: None,
290 })?),
291 i_size,
292 };
293 let attn_norm = rms_norm(
294 ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
295 rms_eps,
296 )?;
297 let ffn_norm = rms_norm(
298 ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
299 rms_eps,
300 )?;
301 let paged_attn = match &attention_mechanism {
302 AttentionImplementation::Eager => None,
303 AttentionImplementation::PagedAttention => {
304 Some(PagedAttention::new(head_dim, device, None)?)
305 }
306 };
307 let qkv =
308 QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?)?;
309 let out =
310 QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_output.weight"), device)?)?;
311 let QMatMul::QTensor(qkv_w) = qkv.clone() else {
312 unreachable!()
313 };
314 let QMatMul::QTensor(out_w) = out.clone() else {
315 unreachable!()
316 };
317 layers.push(LayerWeights {
318 attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
319 q_weight: qkv_w,
320 b: None,
321 })?),
322 attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
323 q_weight: out_w,
324 b: None,
325 })?),
326 attn_norm,
327 ffn_norm,
328 mlp,
329 n_head: head_count,
330 n_kv_head: head_count_kv,
331 head_dim,
332 cos: cos.to_device(device)?,
333 sin: sin.to_device(device)?,
334 paged_attn,
335 sdpa_params: SdpaParams {
336 n_kv_groups: head_count / head_count_kv,
337 softcap: None,
338 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
339 sliding_window: Some(context_window),
340 },
341 dtype,
342 })
343 }
344 Ok(Self {
345 tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
346 layers,
347 output_norm,
348 output,
349 mapper: Some(mapper),
350 device: device.clone(),
351 cache: EitherCache::Normal(NormalCache::new_sliding(
352 block_count,
353 context_window,
354 Some(context_window),
355 )),
356 max_seq_len: context_window,
357 dtype,
358 })
359 }
360}
361
362impl ModelWeights {
363 pub fn forward(
364 &self,
365 input_ids: &Tensor,
366 seqlen_offsets: &[usize],
367 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
368 ) -> Result<Tensor> {
369 let (_b_sz, seq_len) = input_ids.dims2()?;
370 let mut xs = self.tok_embeddings.forward(input_ids)?;
371 let cache = &mut self.cache.normal().0;
372 let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
373 input_ids,
374 metadata
375 .as_ref()
376 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
377 .unwrap_or(cache as &dyn PastKvLenCache),
378 Some(self.max_seq_len),
379 self.dtype,
380 self.layers[0].n_head,
381 )?;
382 let mask = mask.filter(|_| {
383 metadata
384 .as_ref()
385 .map(|(_, meta)| meta.is_first_prompt_chunk)
386 .unwrap_or(true)
387 });
388 for (i, layer) in self.layers.iter().enumerate() {
389 if let Some(ref mapper) = self.mapper {
390 xs = mapper.map(xs, i)?;
391 }
392 let residual = &xs;
393 let ys = xs.apply(&layer.attn_norm)?;
394 let ys = layer.forward_attn(
395 &ys,
396 mask.as_ref()
397 .map(|m| m.to_device(xs.device()).unwrap())
398 .as_ref(),
399 seqlen_offsets,
400 &mut cache[i],
401 metadata
402 .as_ref()
403 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
404 )?;
405 let ys = (ys + residual)?;
406 let residual = &ys;
407 let ys = ys.apply(&layer.ffn_norm)?;
408 let ys = layer.mlp.forward(&ys)?;
409 xs = (ys + residual)?
410 }
411 let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
412 MatMul.qmatmul(&xs, &self.output)
413 }
414}