#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::GgufMatMul;
use mistralrs_quant::QuantMethod;
use mistralrs_quant::QuantMethodConfig;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::MatMul;
use crate::layers::Sdpa;
use crate::layers::{CausalMasker, QLinear};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::AttentionImplementation;
use crate::paged_attention::PagedAttention;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};

pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;

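/// Phi-2 feed-forward block over GGUF-quantized weights: an up projection,
/// GELU activation, and a down projection.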
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Phi-2 MLP: down-projection of the GELU-activated up-projection.
        MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(xs, &*self.ffn_up)?.gelu()?,
            &*self.ffn_down,
        )
    }
}

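/// Per-layer weights for one Phi-2 transformer block: fused QKV and output
/// projections, pre-attention layer norm, MLP, rotary tables, and an optional
/// paged-attention backend.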
struct LayerWeights {
    attn_qkv: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    rope_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
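    /// Applies partial rotary position embeddings: only the first `rope_dim`
    /// features of each head are rotated (using each sequence's offset); the
    /// remaining features pass through unchanged.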
    fn forward(&self, xs: &Tensor, start_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _n_head, seq_len, _n_embd) = xs.dims4()?;
        let xs_rot = xs.i((.., .., .., ..self.rope_dim))?;
        let xs_pass = xs.i((.., .., .., self.rope_dim..))?;
        let mut chunks = Vec::new();
        for (b, offset) in (0..xs.dim(0)?).zip(start_offsets) {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            let xs_rot =
                candle_nn::rotary_emb::rope(&xs_rot.i(b)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
            // Select the pass-through slice for this batch element so the
            // concatenation below has matching batch dimensions.
            let xs_pass = xs_pass.i(b)?.unsqueeze(0)?;
            chunks.push(Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?);
        }
        Tensor::cat(&chunks, 0)?.contiguous()
    }

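    /// Attention for one layer: fused QKV projection, rotary embeddings, then
    /// either paged attention (when KV-cache metadata is supplied) or standard
    /// SDPA over the per-layer KV cache, followed by the output projection.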
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .forward(x)?
            .reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?
            .to_dtype(self.dtype)?;

        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        let v = v.contiguous()?;

        let q = self.forward(&q, seqlen_offsets)?.contiguous()?;
        let k = self.forward(&k, seqlen_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let y = self.attn_output.forward(&y.to_dtype(x.dtype())?)?;
        Ok(y)
    }
}

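/// Quantized Phi-2 model loaded from GGUF.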
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QLinear,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    dtype: DType,
}

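/// Precomputes rotary-embedding cos/sin tables for `max_seq_len` positions with
/// frequencies `1 / freq_base^(i / head_dim)` over even `i`.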
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    max_seq_len: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

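/// Dequantizes the weight and bias tensors and builds a `LayerNorm` from them.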
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    let ln = LayerNorm::new(w, b, eps);
    Ok(ln)
}

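/// Hyperparameters read from the `phi2.*` GGUF metadata keys.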
struct PropsGGUF {
    head_count: usize,
    head_count_kv: usize,
    block_count: usize,
    embedding_length: usize,
    rope_dim: usize,
    ln_eps: f64,
    max_seq_len: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("phi2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            ln_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(DEFAULT_MAX_SEQ_LEN as u64) as usize,
        };

        Ok(props)
    }
}

impl ModelConfig::FromGGUF for ModelWeights {
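    /// Loads all weights from the GGUF content, placing each repeating layer on
    /// the device chosen by the mapper and selecting eager or paged attention.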
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            ln_eps,
            max_seq_len,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = layer_norm(
            ct.tensor("output_norm.weight", device)?,
            ct.tensor("output_norm.bias", device)?,
            ln_eps,
        )?;
        let output = QLinear::new(&mut ct, "output", device)?;
        let mut layers = Vec::with_capacity(block_count);
        let head_dim = embedding_length / head_count;

        // Load each transformer block, mapping it to its target device.
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);

            let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: ffn_up.bias().cloned(),
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: ffn_down.bias().cloned(),
                })?),
            };
            let attn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
                ln_eps,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let qkv = QLinear::new(&mut ct, &format!("{prefix}.attn_qkv"), device)?;
            let out = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
            let QMatMul::QTensor(qkv_w) = qkv.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(out_w) = out.inner_ref().clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: qkv_w,
                    b: qkv.bias().cloned(),
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: out_w,
                    b: out.bias().cloned(),
                })?),
                attn_norm,
                mlp,
                n_head: head_count,
                head_dim,
                cos: cos.clone().to_device(device)?,
                sin: sin.clone().to_device(device)?,
                rope_dim,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper,
            dtype,
        })
    }
}

impl ModelWeights {
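    /// Runs the model over `input_ids` and returns logits for the positions
    /// selected by `context_lens`.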
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // Keep the mask only for the first prompt chunk (or when no
        // paged-attention metadata is present).
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            let residual = &xs;
            let xs_norm = xs.apply(&layer.attn_norm)?;
            let attn_outputs = layer.forward_attn(
                &xs_norm,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let feed_forward_hidden_states = layer.mlp.forward(&xs_norm)?;
            // Parallel block: attention and MLP both read the same normed input
            // and are summed with the residual.
            xs = (attn_outputs + feed_forward_hidden_states + residual)?;
        }
        let xs = xs.to_device(&self.device)?;
        let xs = extract_logits(&xs.apply(&self.output_norm)?, context_lens)?;
        self.output.forward(&xs)
    }
}
391}