mistralrs_core/models/quantized_phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

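//! GGUF-quantized Phi-3: loads `phi3`-architecture GGUF checkpoints and runs
//! them with either eager SDPA or paged attention.
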
use std::sync::Arc;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, RmsNorm, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

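/// Gated feed-forward block: `ffn_up` is a fused projection producing
/// `2 * i_size` features, split in `forward` into a gate half and an up half
/// (`up * silu(gate)`) before the `ffn_down` projection.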
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
    i_size: usize,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let up_states = MatMul.qmethod_matmul(xs, &*self.ffn_up)?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        MatMul.qmethod_matmul(&up_states, &*self.ffn_down)
    }
}

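/// Dequantizes a GGUF norm weight and wraps it in an [`RmsNorm`].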
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

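/// Weights and attention configuration for one transformer block.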
struct LayerWeights {
    attn_qkv: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
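    /// Applies rotary position embeddings, slicing the precomputed cos/sin
    /// tables at each sequence's own offset so batched sequences with
    /// different past lengths are rotated correctly.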
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

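    /// Attention for one block: the fused QKV projection is split into Q, K,
    /// and V, rotary embeddings are applied, and attention is computed with
    /// paged attention when available, otherwise with SDPA over the KV cache.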
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;
        let qkv = MatMul
            .qmethod_matmul(x, &*self.attn_qkv)?
            .to_dtype(self.dtype)?;
        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v.contiguous()?)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;
        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attn_output)?;
        Ok(y)
    }
}

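/// The full quantized Phi-3 model: token embeddings, transformer blocks,
/// final norm, and output head.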
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QMatMul,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    dtype: DType,
}

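/// Precomputes RoPE tables: frequencies `theta_j = freq_base^(-2j / head_dim)`
/// for `j` in `0..head_dim / 2`, multiplied by every position in
/// `0..context_window`, returned as `(cos, sin)` tensors of shape
/// `(context_window, head_dim / 2)`.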
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

// phi3 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub i_size: usize,
    pub rope_dim: usize,
    pub rms_eps: f64,
    pub context_window: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("phi3")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "feed_forward_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            i_size: c.get_value::<u32>("feed_forward_length")? as usize,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            rms_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
            context_window: c.get_value::<u32>("context_length")? as usize,
        };

        Ok(props)
    }
}

impl ModelConfig::FromGGUF for ModelWeights {
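    /// Builds the model from GGUF content: reads the `phi3.*` metadata, then
    /// precomputes the RoPE tables and loads the embedding, per-block
    /// (`blk.{i}.*`), and output tensors onto their mapped devices.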
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
        let mut layers = Vec::with_capacity(block_count);
        let head_dim = embedding_length / head_count;

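        // Each block's tensors live under `blk.{layer_idx}`. The loaded
        // `QMatMul`s are unwrapped back into `QTensor`s so they can be
        // rewrapped as `GgufMatMul` quant methods.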
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?)?;
            let ffn_down =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: None,
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: None,
                })?),
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let qkv =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?)?;
            let out =
                QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.attn_output.weight"), device)?)?;
            let QMatMul::QTensor(qkv_w) = qkv.clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(out_w) = out.clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: qkv_w,
                    b: None,
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: out_w,
                    b: None,
                })?),
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new_sliding(
                block_count,
                context_window,
                Some(context_window),
            )),
            max_seq_len: context_window,
            dtype,
        })
    }
}

impl ModelWeights {
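    /// Embeds `input_ids`, builds a sliding-window causal mask, runs each
    /// block (attention and MLP, both with residual connections), and returns
    /// the output-head logits for the last position.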
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (_b_sz, seq_len) = input_ids.dims2()?;
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
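        // When paged attention metadata is present, keep the mask only for
        // the first prompt chunk.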
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(&ys)?;
            xs = (ys + residual)?
        }
        let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
        MatMul.qmatmul(&xs, &self.output)
    }
}
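
// A minimal sanity check for the RoPE tables, added as a sketch: it assumes a
// CPU test environment and exercises only `precomput_freqs_cis`, not the full
// model.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rope_tables_have_expected_shape_and_first_row() -> Result<()> {
        let device = Device::Cpu;
        // head_dim = 8 gives head_dim / 2 = 4 frequency columns; one row per
        // position in 0..context_window.
        let (cos, sin) = precomput_freqs_cis(8, 10_000., &device, 4, DType::F32)?;
        assert_eq!(cos.dims(), &[4, 4]);
        assert_eq!(sin.dims(), &[4, 4]);
        // Position 0 has angle 0 in every column: cos = 1, sin = 0.
        let cos0 = cos.i(0)?.to_vec1::<f32>()?;
        let sin0 = sin.i(0)?.to_vec1::<f32>()?;
        assert!(cos0.iter().all(|&c| (c - 1.0).abs() < 1e-6));
        assert!(sin0.iter().all(|&s| s.abs() < 1e-6));
        Ok(())
    }
}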