#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm};
use indicatif::MultiProgress;
use mistralrs_quant::GgufMatMul;
use mistralrs_quant::QuantMethod;
use mistralrs_quant::QuantMethodConfig;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::MatMul;
use crate::layers::Sdpa;
use crate::layers::{CausalMasker, QLinear};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::AttentionImplementation;
use crate::paged_attention::PagedAttention;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;

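// Default context length, used as a fallback when the GGUF metadata does not
// provide a "context_length" value.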
pub const MAX_SEQ_LEN: usize = 4096;

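// Feed-forward block over quantized weights: up-projection, GELU, down-projection.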
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Apply GELU between the up- and down-projections (phi-2's FFN activation).
        MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(xs, &*self.ffn_up)?.gelu()?,
            &*self.ffn_down,
        )
    }
}

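// Per-layer weights: fused QKV projection, output projection, pre-attention
// LayerNorm, the MLP, and the precomputed RoPE cos/sin tables.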
struct LayerWeights {
    attn_qkv: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    rope_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
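    // Applies rotary position embeddings to the first `rope_dim` dimensions of each
    // head; the remaining dimensions are passed through unchanged (partial RoPE).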
    fn forward(&self, xs: &Tensor, start_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _n_head, seq_len, _n_embd) = xs.dims4()?;
        let xs_rot = xs.i((.., .., .., ..self.rope_dim))?;
        let xs_pass = xs.i((.., .., .., self.rope_dim..))?;
        let mut chunks = Vec::new();
        for (b, offset) in (0..xs.dim(0)?).zip(start_offsets) {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            let xs_rot =
                candle_nn::rotary_emb::rope(&xs_rot.i(b)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
            // Index the pass-through slice for the same batch element so both halves
            // carry a leading batch dimension of 1 before concatenation.
            let xs_pass = xs_pass.i(b)?.unsqueeze(0)?;
            chunks.push(Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?);
        }
        Tensor::cat(&chunks, 0)?.contiguous()
    }

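    // Attention for one layer: project to fused QKV, apply RoPE to Q and K, then run
    // either paged attention or SDPA over the (cached) keys and values.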
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .forward(x)?
            .reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?
            .to_dtype(self.dtype)?;

        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        let v = v.contiguous()?;

        let q = self.forward(&q, seqlen_offsets)?.contiguous()?;
        let k = self.forward(&k, seqlen_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let y = self.attn_output.forward(&y.to_dtype(x.dtype())?)?;
        Ok(y)
    }
}

pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QLinear,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    dtype: DType,
}

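// Precomputes the RoPE cos/sin tables for positions 0..max_seq_len over the rotary
// dimensions.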
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    max_seq_len: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    let ln = LayerNorm::new(w, b, eps);
    Ok(ln)
}

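// Model hyperparameters read from the GGUF metadata (keys under the "phi2." prefix).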
struct PropsGGUF {
    head_count: usize,
    head_count_kv: usize,
    block_count: usize,
    embedding_length: usize,
    rope_dim: usize,
    ln_eps: f64,
    max_seq_len: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("phi2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            ln_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(MAX_SEQ_LEN as u64) as usize,
        };

        Ok(props)
    }
}

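// Builds the full model from a GGUF file: global tensors first, then the repeating
// transformer blocks.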
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
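        // Parameter extraction from metadata.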
        let metadata = ContentMetadata {
            path_prefix: "phi2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            ln_eps,
            max_seq_len,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = layer_norm(
            ct.tensor("output_norm.weight", device)?,
            ct.tensor("output_norm.bias", device)?,
            ln_eps,
        )?;
        let output = QLinear::new(&mut ct, "output", device)?;
        let mut layers = Vec::with_capacity(block_count);
        let head_dim = embedding_length / head_count;

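        // Load the repeating transformer blocks, placing each layer's tensors on the
        // device selected by the mapper.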
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);

            let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: ffn_up.bias().cloned(),
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: ffn_down.bias().cloned(),
                })?),
            };
            let attn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
                ln_eps,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let qkv = QLinear::new(&mut ct, &format!("{prefix}.attn_qkv"), device)?;
            let out = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
            let QMatMul::QTensor(qkv_w) = qkv.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(out_w) = out.inner_ref().clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: qkv_w,
                    b: qkv.bias().cloned(),
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: out_w,
                    b: out.bias().cloned(),
                })?),
                attn_norm,
                mlp,
                n_head: head_count,
                head_dim,
                cos: cos.clone().to_device(device)?,
                sin: sin.clone().to_device(device)?,
                rope_dim,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper,
            dtype,
        })
    }
}

impl ModelWeights {
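    /// Forward pass: token embeddings, all transformer blocks (with KV caching or
    /// paged attention), the final LayerNorm, and the quantized output head.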
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
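        // phi-2 style parallel residual: attention and the MLP both read the same
        // normed input, and their outputs are added back to the residual stream.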
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            let residual = &xs;
            let xs_norm = xs.apply(&layer.attn_norm)?;
            let attn_outputs = layer.forward_attn(
                &xs_norm,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let feed_forward_hidden_states = layer.mlp.forward(&xs_norm)?;
            xs = (attn_outputs + feed_forward_hidden_states + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        let xs = extract_logits(&xs.apply(&self.output_norm)?, context_lens)?;
        self.output.forward(&xs)
    }
}