mistralrs_core/models/quantized_phi2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::GgufMatMul;
use mistralrs_quant::QuantMethod;
use mistralrs_quant::QuantMethodConfig;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::MatMul;
use crate::layers::Sdpa;
use crate::layers::{CausalMasker, QLinear};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::AttentionImplementation;
use crate::paged_attention::PagedAttention;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};

pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;

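/// Feed-forward block of a Phi-2 layer: an up-projection, a GELU activation, and a
/// down-projection (see the `Module` impl below). A rough sketch of the same dataflow
/// with ordinary (non-quantized) linear layers, using hypothetical `w_up`/`w_down`
/// weights, would be:
///
/// ```ignore
/// let h = xs.apply(&w_up)?.gelu()?; // hidden -> intermediate, then GELU
/// let y = h.apply(&w_down)?;        // intermediate -> hidden
/// ```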
#[derive(Clone)]
struct Mlp {
    ffn_up: Arc<dyn QuantMethod>,
    ffn_down: Arc<dyn QuantMethod>,
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Up-projection, GELU activation, then down-projection.
        MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(xs, &*self.ffn_up)?.gelu()?,
            &*self.ffn_down,
        )
    }
}

struct LayerWeights {
    attn_qkv: Arc<dyn QuantMethod>,
    attn_output: Arc<dyn QuantMethod>,
    attn_norm: LayerNorm,
    mlp: Mlp,
    n_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    rope_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
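    /// Applies rotary position embeddings (RoPE) to `xs`. Phi-2 uses *partial* rotary
    /// embeddings: only the first `rope_dim` features of each head are rotated, the
    /// remaining `head_dim - rope_dim` features pass through unchanged, and the two
    /// parts are concatenated again. Each batch row is rotated with its own entry of
    /// `start_offsets`, so sequences at different decode offsets can share a batch.
    /// Rough per-row sketch (illustrative only):
    ///
    /// ```ignore
    /// // rotate the first rope_dim dims, keep the rest as-is
    /// let rotated = rope(&x.i((.., .., .., ..rope_dim))?, &cos, &sin)?;
    /// let out = Tensor::cat(&[&rotated, &x.i((.., .., .., rope_dim..))?], D::Minus1)?;
    /// ```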
    fn forward(&self, xs: &Tensor, start_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _n_head, seq_len, _n_embd) = xs.dims4()?;
        let xs_rot = xs.i((.., .., .., ..self.rope_dim))?;
        let xs_pass = xs.i((.., .., .., self.rope_dim..))?;
        let mut chunks = Vec::new();
        for (b, offset) in (0..xs.dim(0)?).zip(start_offsets) {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            let xs_rot =
                candle_nn::rotary_emb::rope(&xs_rot.i(b)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
            // Re-attach the un-rotated part of this batch row.
            let xs_pass = xs_pass.i(b)?.unsqueeze(0)?;
            chunks.push(Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?);
        }
        Tensor::cat(&chunks, 0)?.contiguous()
    }

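    /// Attention for one layer. The fused `attn_qkv` projection is evaluated once and
    /// reshaped to `(b, s, 3, n_head, head_dim)`, then split into Q/K/V along dim 2;
    /// RoPE is applied to Q and K, and scaled-dot-product attention runs either through
    /// the paged-attention kernel (when `paged_attn`/`metadata` are present) or through
    /// `Sdpa` with the in-place `KvCache`. Splitting the fused projection looks roughly
    /// like this (illustrative sketch):
    ///
    /// ```ignore
    /// let qkv = qkv_proj(x)?.reshape((b, s, 3, n_head, head_dim))?;
    /// let q = qkv.i((.., .., 0))?.transpose(1, 2)?; // (b, n_head, s, head_dim)
    /// let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
    /// let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
    /// ```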
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .forward(x)?
            .reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?
            .to_dtype(self.dtype)?;

        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        // This call to contiguous ensures that the fast kernel can be called below. It's
        // actually a no-op except when processing the initial prompt so has no significant
        // impact on performance.
        let v = v.contiguous()?;

        let q = self.forward(&q, seqlen_offsets)?.contiguous()?;
        let k = self.forward(&k, seqlen_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let y = self.attn_output.forward(&y.to_dtype(x.dtype())?)?;
        Ok(y)
    }
}

pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: LayerNorm,
    output: QLinear,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    dtype: DType,
}

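/// Precomputes the RoPE cos/sin tables for all positions up to `max_seq_len`. For the
/// i-th feature pair the inverse frequency is `1 / freq_base^(2i / head_dim)`, and the
/// angle at position `t` is `t * inv_freq[i]`, so both returned tensors have shape
/// `(max_seq_len, head_dim / 2)`. In Rust-flavoured pseudocode:
///
/// ```ignore
/// // angle[t][i] = t / freq_base^(2 * i / head_dim)
/// let cos = angle.cos()?; // (max_seq_len, head_dim / 2)
/// let sin = angle.sin()?;
/// ```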
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    max_seq_len: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

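/// Builds a standard `LayerNorm` from quantized GGUF weight/bias tensors by
/// dequantizing each on its own device; the norm itself then runs in full precision.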
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
    let w = w.dequantize(&w.device())?;
    let b = b.dequantize(&b.device())?;
    let ln = LayerNorm::new(w, b, eps);
    Ok(ln)
}

// phi2 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
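// Key names below are relative to the metadata `path_prefix` ("phi2") set in
// `from_gguf`, so e.g. `attention.head_count` presumably resolves to the GGUF entry
// `phi2.attention.head_count` and `context_length` to `phi2.context_length`.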
struct PropsGGUF {
    head_count: usize,
    head_count_kv: usize,
    block_count: usize,
    embedding_length: usize,
    rope_dim: usize,
    ln_eps: f64,
    max_seq_len: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("phi2")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
            "context_length",
        ];
        c.has_required_keys(&required)?;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            head_count: c.get_value::<u32>("attention.head_count")? as usize,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: c.get_value::<u32>("embedding_length")? as usize,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            ln_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(DEFAULT_MAX_SEQ_LEN as u64) as usize,
        };

        Ok(props)
    }
}

impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi2",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            ln_eps,
            max_seq_len,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = layer_norm(
            ct.tensor("output_norm.weight", device)?,
            ct.tensor("output_norm.bias", device)?,
            ln_eps,
        )?;
        let output = QLinear::new(&mut ct, "output", device)?;
        let mut layers = Vec::with_capacity(block_count);
        let head_dim = embedding_length / head_count;

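        // Per-layer tensors follow the GGUF naming convention `blk.<i>.<name>`: for
        // layer 0, e.g., `blk.0.attn_qkv`, `blk.0.attn_output`, `blk.0.ffn_up`,
        // `blk.0.ffn_down`, and `blk.0.attn_norm.{weight,bias}`.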
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);

            let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?;
            let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?;
            let QMatMul::QTensor(ffn_up_w) = ffn_up.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(ffn_down_w) = ffn_down.inner_ref().clone() else {
                unreachable!()
            };
            let mlp = Mlp {
                ffn_up: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_up_w,
                    b: ffn_up.bias().cloned(),
                })?),
                ffn_down: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: ffn_down_w,
                    b: ffn_down.bias().cloned(),
                })?),
            };
            let attn_norm = layer_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?,
                ln_eps,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let qkv = QLinear::new(&mut ct, &format!("{prefix}.attn_qkv"), device)?;
            let out = QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?;
            let QMatMul::QTensor(qkv_w) = qkv.inner_ref().clone() else {
                unreachable!()
            };
            let QMatMul::QTensor(out_w) = out.inner_ref().clone() else {
                unreachable!()
            };
            layers.push(LayerWeights {
                attn_qkv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: qkv_w,
                    b: qkv.bias().cloned(),
                })?),
                attn_output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: out_w,
                    b: out.bias().cloned(),
                })?),
                attn_norm,
                mlp,
                n_head: head_count,
                head_dim,
                cos: cos.clone().to_device(device)?,
                sin: sin.clone().to_device(device)?,
                rope_dim,
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper,
            dtype,
        })
    }
}

impl ModelWeights {
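    /// Full forward pass. Phi-2 uses a *parallel* residual layout: each block applies a
    /// single LayerNorm, feeds the normalized activations to the attention and MLP
    /// branches independently, and sums both outputs with the residual, i.e. roughly:
    ///
    /// ```ignore
    /// let h = layer_norm(x)?;
    /// let x = (attn(h)? + mlp(h)? + x)?; // per layer
    /// ```
    ///
    /// The final hidden states go through `output_norm`, `extract_logits` keeps only the
    /// positions requested in `context_lens`, and the `output` head produces the logits.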
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            let residual = &xs;
            let xs_norm = xs.apply(&layer.attn_norm)?;
            let attn_outputs = layer.forward_attn(
                &xs_norm,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let feed_forward_hidden_states = layer.mlp.forward(&xs_norm)?;
            xs = (attn_outputs + feed_forward_hidden_states + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        let xs = extract_logits(&xs.apply(&self.output_norm)?, context_lens)?;
        self.output.forward(&xs)
    }
}
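
#[cfg(test)]
mod tests {
    use super::*;

    // Editorial sketch (not part of the original file): sanity-check the RoPE table
    // shapes produced by `precomput_freqs_cis`. With `head_dim = 64` the `theta` vector
    // has 32 entries, so for `max_seq_len = 16` both tables should be `(16, 32)`.
    #[test]
    fn rope_tables_have_expected_shape() -> Result<()> {
        let (cos, sin) = precomput_freqs_cis(64, 10_000., &Device::Cpu, 16, DType::F32)?;
        assert_eq!(cos.dims(), &[16usize, 32]);
        assert_eq!(sin.dims(), &[16usize, 32]);
        Ok(())
    }
}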