mistralrs_core/models/quantized_phi2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm};
use indicatif::MultiProgress;
use mistralrs_quant::GgufMatMul;
use mistralrs_quant::QuantMethod;
use mistralrs_quant::QuantMethodConfig;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::MatMul;
use crate::layers::Sdpa;
use crate::layers::{CausalMasker, QLinear};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::AttentionImplementation;
use crate::paged_attention::PagedAttention;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;

pub const MAX_SEQ_LEN: usize = 4096;

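/// Feed-forward block of the quantized Phi-2 model: an up-projection, a GELU
/// activation, then a down-projection, with both projections running as
/// GGUF-quantized matmuls.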
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Up-project, apply GELU, then down-project.
        MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(xs, &*self.ffn_up)?.gelu()?,
            &*self.ffn_down,
        )
    }
}

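/// Weights for a single Phi-2 transformer block: the fused QKV projection, the
/// attention output projection, the LayerNorm feeding both the attention and MLP
/// branches, the MLP itself, per-layer RoPE tables, and an optional paged-attention
/// backend alongside the shared SDPA parameters.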
struct LayerWeights {
    attn_qkv: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    rope_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
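    // Applies rotary position embeddings to the first `rope_dim` features of each
    // head (Phi-2 uses partial RoPE); the remaining features pass through unchanged.
    // Each batch element is rotated starting at its own offset from `start_offsets`.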
    fn forward(&self, xs: &Tensor, start_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _n_head, seq_len, _n_embd) = xs.dims4()?;
        let xs_rot = xs.i((.., .., .., ..self.rope_dim))?;
        let xs_pass = xs.i((.., .., .., self.rope_dim..))?;
        let mut chunks = Vec::new();
        for (b, offset) in (0..xs.dim(0)?).zip(start_offsets) {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            let xs_rot =
                candle_nn::rotary_emb::rope(&xs_rot.i(b)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
            // Select the pass-through slice for this batch element so the
            // concatenation below lines up with the rotated slice.
            let xs_pass = xs_pass.i(b)?.unsqueeze(0)?;
            chunks.push(Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?);
        }
        Tensor::cat(&chunks, 0)?.contiguous()
    }

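    // Attention for one block: project to fused QKV, split into Q/K/V heads, apply
    // partial RoPE to Q and K, then attend either through PagedAttention (when
    // paged-attention metadata is supplied) or through SDPA over the in-place KV
    // cache, and finally apply the output projection.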
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .forward(x)?
            .reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?
            .to_dtype(self.dtype)?;

        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        // This call to contiguous ensures that the fast kernel can be called below. It's
        // actually a no-op except when processing the initial prompt so has no significant
        // impact on performance.
        let v = v.contiguous()?;

        let q = self.forward(&q, seqlen_offsets)?.contiguous()?;
        let k = self.forward(&k, seqlen_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let y = self.attn_output.forward(&y.to_dtype(x.dtype())?)?;
        Ok(y)
    }
}

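/// Quantized Phi-2 model loaded from GGUF: token embeddings, the stack of
/// transformer blocks, the final LayerNorm and output head, plus the KV cache and
/// the device mapper used to place layers at inference time.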
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QLinear,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    dtype: DType,
}

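// Precomputes the RoPE cos/sin tables. For even feature index `i` the inverse
// frequency is `freq_base^(-i / head_dim)`, and row `pos` of each table holds
// cos/sin of `pos * inv_freq` for positions `0..max_seq_len`.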
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    max_seq_len: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

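// Builds a standard LayerNorm by dequantizing the GGUF weight and bias tensors on
// their own devices; Phi-2 uses LayerNorm with a bias rather than RMSNorm.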
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    let ln = LayerNorm::new(w, b, eps);
    Ok(ln)
}

// phi2 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
struct PropsGGUF {
    head_count: usize,
    head_count_kv: usize,
    block_count: usize,
    embedding_length: usize,
    rope_dim: usize,
    ln_eps: f64,
    max_seq_len: usize,
}

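// Lookups below are relative to the architecture prefix ("phi2") set via
// `path_prefix` in `from_gguf`, so e.g. "attention.head_count" corresponds to the
// GGUF key `phi2.attention.head_count`.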
impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("phi2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            ln_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(MAX_SEQ_LEN as u64) as usize,
        };

        Ok(props)
    }
}

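// Loads the model from GGUF content: parse the `phi2.*` metadata, precompute the
// RoPE tables, dequantize the token embedding and norm tensors, and wrap each
// layer's quantized weights in `GgufMatMul` so they run through the `QuantMethod`
// path.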
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            ln_eps,
            max_seq_len,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = layer_norm(
            ct.tensor("output_norm.weight", device)?,
            ct.tensor("output_norm.bias", device)?,
            ln_eps,
        )?;
        let output = QLinear::new(&mut ct, "output", device)?;
        let mut layers = Vec::with_capacity(block_count);
        let head_dim = embedding_length / head_count;

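        // Load each transformer block, placing its tensors on the device selected by
        // the device mapper and re-wrapping the GGUF-quantized weights (and their
        // biases) as `GgufMatMul` quant methods.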
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);

            let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: ffn_up.bias().cloned(),
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: ffn_down.bias().cloned(),
                })?),
            };
            let attn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
                ln_eps,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let qkv = QLinear::new(&mut ct, &format!("{prefix}.attn_qkv"), device)?;
            let out = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
            let QMatMul::QTensor(qkv_w) = qkv.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(out_w) = out.inner_ref().clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: qkv_w,
                    b: qkv.bias().cloned(),
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: out_w,
                    b: out.bias().cloned(),
                })?),
                attn_norm,
                mlp,
                n_head: head_count,
                head_dim,
                cos: cos.clone().to_device(device)?,
                sin: sin.clone().to_device(device)?,
                rope_dim,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper,
            dtype,
        })
    }
}

impl ModelWeights {
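    /// Runs a forward pass: embed the input ids, build the causal mask (skipped for
    /// non-initial paged-attention prompt chunks), apply each transformer block, then
    /// normalize and project the positions selected by `context_lens` to logits.
    /// `seqlen_offsets` carries the number of positions already cached per sequence;
    /// `metadata` carries the paged-attention KV caches and block tables when that
    /// backend is active.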
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
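        // Phi-2 uses a parallel residual: the attention and MLP branches both read
        // the same LayerNorm output and are summed with the residual stream.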
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            let residual = &xs;
            let xs_norm = xs.apply(&layer.attn_norm)?;
            let attn_outputs = layer.forward_attn(
                &xs_norm,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let feed_forward_hidden_states = layer.mlp.forward(&xs_norm)?;
            xs = (attn_outputs + feed_forward_hidden_states + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        let xs = extract_logits(&xs.apply(&self.output_norm)?, context_lens)?;
        self.output.forward(&xs)
    }
}