mistralrs_core/models/
quantized_phi3.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use crate::attention::SdpaParams;
6use crate::device_map::DeviceMapper;
7use crate::gguf::Content;
8use crate::layers::{CausalMasker, MatMul, RmsNorm, Sdpa};
9use crate::layers_masker::PastKvLenCache;
10use crate::paged_attention::{AttentionImplementation, PagedAttention};
11use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
12use crate::pipeline::{EitherCache, KvCache, NormalCache};
13use crate::utils::gguf_metadata::ContentMetadata;
14use crate::utils::model_config as ModelConfig;
15use crate::utils::progress::{new_multi_progress, NiceProgressBar};
16use candle_core::quantized::QMatMul;
17use candle_core::quantized::QTensor;
18use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
19use candle_nn::Embedding;
20use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};
21
22#[derive(Clone)]
23struct Mlp {
24    ffn_up: Arc<dyn QuantMethod>,
25    ffn_down: Arc<dyn QuantMethod>,
26    i_size: usize,
27}
28
29impl Module for Mlp {
30    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
31        let up_states = MatMul.qmethod_matmul(xs, &*self.ffn_up)?;
32        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
33        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
34        let up_states = (up_states * gate.silu()?)?;
35        MatMul.qmethod_matmul(&up_states, &*self.ffn_down)
36    }
37}
38
39fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
40    let w = w.dequantize(&w.device())?;
41    let rms = RmsNorm::from_w(w, eps)?;
42    Ok(rms)
43}
44
45struct LayerWeights {
46    attn_qkv: Arc<dyn QuantMethod>,
47    attn_output: Arc<dyn QuantMethod>,
48    attn_norm: RmsNorm,
49    ffn_norm: RmsNorm,
50    mlp: Mlp,
51    n_head: usize,
52    n_kv_head: usize,
53    head_dim: usize,
54    cos: Tensor,
55    sin: Tensor,
56    paged_attn: Option<PagedAttention>,
57    sdpa_params: SdpaParams,
58    dtype: DType,
59}
60
61impl LayerWeights {
62    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
63        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
64        let mut outputs = Vec::new();
65        for (i, offset) in seqlen_offsets.iter().enumerate() {
66            let cos = self.cos.narrow(0, *offset, seq_len)?;
67            let sin = self.sin.narrow(0, *offset, seq_len)?;
68            outputs.push(candle_nn::rotary_emb::rope(
69                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
70                &cos,
71                &sin,
72            )?);
73        }
74        Tensor::cat(&outputs, 0)
75    }
76
77    fn forward_attn(
78        &self,
79        x: &Tensor,
80        mask: Option<&Tensor>,
81        seqlen_offsets: &[usize],
82        kv_cache: &mut KvCache,
83        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
84    ) -> Result<Tensor> {
85        let (b_sz, seq_len, _) = x.dims3()?;
86        let qkv = MatMul
87            .qmethod_matmul(x, &*self.attn_qkv)?
88            .to_dtype(self.dtype)?;
89        let query_pos = self.n_head * self.head_dim;
90        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
91        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
92        let v = qkv.narrow(
93            D::Minus1,
94            query_pos + self.n_kv_head * self.head_dim,
95            self.n_kv_head * self.head_dim,
96        )?;
97
98        let (q, k, v) = if seq_len != 1 {
99            let q = q
100                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
101                .transpose(1, 2)?;
102            let k = k
103                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
104                .transpose(1, 2)?;
105            let v = v
106                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
107                .transpose(1, 2)?;
108            (q, k, v.contiguous()?)
109        } else {
110            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
111            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
112            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
113            (q, k, v)
114        };
115
116        let q = self.apply_rotary_emb(&q, seqlen_offsets)?;
117        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;
118        let y = match &self.paged_attn {
119            Some(paged_attn) => {
120                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
121                paged_attn.forward(
122                    &q,
123                    &k,
124                    &v,
125                    mask,
126                    Some(key_cache),
127                    Some(value_cache),
128                    input_metadata,
129                    &self.sdpa_params,
130                    None,
131                )?
132            }
133            None => {
134                let (k, v) = kv_cache.append(&k, &v)?;
135
136                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
137            }
138        };
139
140        let y = if mask.is_some() {
141            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
142        } else {
143            y.reshape((b_sz, seq_len, ()))?
144        };
145        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)?;
146        Ok(y)
147    }
148}
149
150pub struct ModelWeights {
151    tok_embeddings: Embedding,
152    layers: Vec<LayerWeights>,
153    output_norm: RmsNorm,
154    output: QMatMul,
155    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
156    pub device: Device,
157    pub cache: EitherCache,
158    pub max_seq_len: usize,
159    dtype: DType,
160}
161
162fn precomput_freqs_cis(
163    head_dim: usize,
164    freq_base: f32,
165    device: &Device,
166    context_window: usize,
167    dtype: DType,
168) -> Result<(Tensor, Tensor)> {
169    let theta: Vec<_> = (0..head_dim)
170        .step_by(2)
171        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
172        .collect();
173    let theta = Tensor::new(theta.as_slice(), device)?;
174    let idx_theta = Tensor::arange(0, context_window as u32, device)?
175        .to_dtype(DType::F32)?
176        .reshape((context_window, 1))?
177        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
178    let cos = idx_theta.cos()?.to_dtype(dtype)?;
179    let sin = idx_theta.sin()?.to_dtype(dtype)?;
180    Ok((cos, sin))
181}
182
// phi3 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
/// Model hyper-parameters read from the `phi3` section of the GGUF metadata.
pub(crate) struct PropsGGUF {
    /// `attention.head_count`
    pub head_count: usize,
    /// `attention.head_count_kv`
    pub head_count_kv: usize,
    /// `block_count`
    pub block_count: usize,
    /// `embedding_length`
    pub embedding_length: usize,
    /// `feed_forward_length`
    pub i_size: usize,
    /// `rope.dimension_count`
    pub rope_dim: usize,
    /// `attention.layer_norm_rms_epsilon`
    pub rms_eps: f64,
    /// `context_length`
    pub context_window: usize,
}
196
197impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
198    type Error = anyhow::Error;
199
200    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
201        c.verify_arch("phi3")?;
202
203        let required = [
204            "attention.head_count",
205            "attention.head_count_kv",
206            "block_count",
207            "embedding_length",
208            "feed_forward_length",
209            "rope.dimension_count",
210            "attention.layer_norm_rms_epsilon",
211            "context_length",
212        ];
213        c.has_required_keys(&required)?;
214
215        // NOTE: Values are not aligned with GGUFv3 types
216        // TODO: Normalize value types to spec
217        let props = Self {
218            head_count: c.get_value::<u32>("attention.head_count")? as usize,
219            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
220            block_count: c.get_value::<u32>("block_count")? as usize,
221            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
222            i_size: c.get_value::<u32>("feed_forward_length")? as usize,
223            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
224            rms_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
225            context_window: c.get_value::<u32>("context_length")? as usize,
226        };
227
228        Ok(props)
229    }
230}
231
232impl ModelConfig::FromGGUF for ModelWeights {
233    fn from_gguf<R: std::io::Seek + std::io::Read>(
234        mut ct: Content<'_, R>,
235        device: &Device,
236        mapper: Box<dyn DeviceMapper + Send + Sync>,
237        attention_mechanism: AttentionImplementation,
238        dtype: DType,
239    ) -> Result<Self> {
240        // Parameter extraction from metadata.
241        let metadata = ContentMetadata {
242            path_prefix: "phi3",
243            metadata: ct.get_metadata(),
244        };
245        let PropsGGUF {
246            head_count,
247            head_count_kv,
248            block_count,
249            embedding_length,
250            i_size,
251            rope_dim,
252            rms_eps,
253            context_window,
254        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;
255
256        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;
257
258        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
259        let tok_embeddings = tok_embeddings.dequantize(device)?;
260        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
261        let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
262        let mut layers = Vec::with_capacity(block_count);
263        let head_dim = embedding_length / head_count;
264
265        for layer_idx in NiceProgressBar::<_, 'b'>(
266            0..block_count,
267            "Loading repeating layers",
268            &new_multi_progress(),
269        ) {
270            let prefix = format!("blk.{layer_idx}");
271            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
272            let ffn_up =
273                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?)?;
274            let ffn_down =
275                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?)?;
276            let QMatMul::QTensor(ffn_up_w) = ffn_up else {
277                unreachable!()
278            };
279            let QMatMul::QTensor(ffn_down_w) = ffn_down else {
280                unreachable!()
281            };
282            let mlp = Mlp {
283                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
284                    q_weight: ffn_up_w,
285                    b: None,
286                })?),
287                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
288                    q_weight: ffn_down_w,
289                    b: None,
290                })?),
291                i_size,
292            };
293            let attn_norm = rms_norm(
294                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
295                rms_eps,
296            )?;
297            let ffn_norm = rms_norm(
298                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
299                rms_eps,
300            )?;
301            let paged_attn = match &attention_mechanism {
302                AttentionImplementation::Eager => None,
303                AttentionImplementation::PagedAttention => {
304                    Some(PagedAttention::new(head_dim, device, None)?)
305                }
306            };
307            let qkv =
308                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?)?;
309            let out =
310                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_output.weight"), device)?)?;
311            let QMatMul::QTensor(qkv_w) = qkv.clone() else {
312                unreachable!()
313            };
314            let QMatMul::QTensor(out_w) = out.clone() else {
315                unreachable!()
316            };
317            layers.push(LayerWeights {
318                attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
319                    q_weight: qkv_w,
320                    b: None,
321                })?),
322                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
323                    q_weight: out_w,
324                    b: None,
325                })?),
326                attn_norm,
327                ffn_norm,
328                mlp,
329                n_head: head_count,
330                n_kv_head: head_count_kv,
331                head_dim,
332                cos: cos.to_device(device)?,
333                sin: sin.to_device(device)?,
334                paged_attn,
335                sdpa_params: SdpaParams {
336                    n_kv_groups: head_count / head_count_kv,
337                    softcap: None,
338                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
339                    sliding_window: Some(context_window),
340                },
341                dtype,
342            })
343        }
344        Ok(Self {
345            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
346            layers,
347            output_norm,
348            output,
349            mapper: Some(mapper),
350            device: device.clone(),
351            cache: EitherCache::Normal(NormalCache::new_sliding(
352                block_count,
353                context_window,
354                Some(context_window),
355            )),
356            max_seq_len: context_window,
357            dtype,
358        })
359    }
360}
361
362impl ModelWeights {
363    pub fn forward(
364        &self,
365        input_ids: &Tensor,
366        seqlen_offsets: &[usize],
367        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
368    ) -> Result<Tensor> {
369        let (_b_sz, seq_len) = input_ids.dims2()?;
370        let mut xs = self.tok_embeddings.forward(input_ids)?;
371        let cache = &mut self.cache.normal().0;
372        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
373            input_ids,
374            metadata
375                .as_ref()
376                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
377                .unwrap_or(cache as &dyn PastKvLenCache),
378            Some(self.max_seq_len),
379            self.dtype,
380            self.layers[0].n_head,
381        )?;
382        let mask = mask.filter(|_| {
383            metadata
384                .as_ref()
385                .map(|(_, meta)| meta.is_first_prompt_chunk)
386                .unwrap_or(true)
387        });
388        for (i, layer) in self.layers.iter().enumerate() {
389            if let Some(ref mapper) = self.mapper {
390                xs = mapper.map(xs, i)?;
391            }
392            let residual = &xs;
393            let ys = xs.apply(&layer.attn_norm)?;
394            let ys = layer.forward_attn(
395                &ys,
396                mask.as_ref()
397                    .map(|m| m.to_device(xs.device()).unwrap())
398                    .as_ref(),
399                seqlen_offsets,
400                &mut cache[i],
401                metadata
402                    .as_ref()
403                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
404            )?;
405            let ys = (ys + residual)?;
406            let residual = &ys;
407            let ys = ys.apply(&layer.ffn_norm)?;
408            let ys = layer.mlp.forward(&ys)?;
409            xs = (ys + residual)?
410        }
411        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
412        MatMul.qmatmul(&xs, &self.output)
413    }
414}