mistralrs_core/xlora_models/quantized_phi3.rs
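//! X-LoRA / LoRA support for the GGUF-quantized Phi-3 model.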

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

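/// Projection names that adapters may target in this model; the adapter `Ordering`
/// is validated against this list before loading.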
const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

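/// Phi-3 MLP with a fused gate/up projection: `ffn_up` produces both halves, which are
/// split at `i_size`, combined as `up * silu(gate)`, and passed through `ffn_down`.
/// Both projections are LoRA-capable.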
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

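/// Per-layer weights: fused QKV and output projections (LoRA-capable), pre-attention
/// and pre-FFN RMS norms, precomputed RoPE cos/sin tables, and the SDPA parameters
/// (including the sliding window) used by the attention kernel.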
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
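    /// Applies rotary position embeddings per sequence in the batch, slicing the
    /// precomputed cos/sin tables at each sequence's position offset.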
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

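    /// Attention for one layer: fused QKV projection (with optional LoRA scalings),
    /// RoPE, a sliding-window KV-cache update, SDPA, then the output projection.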
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) =
            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, Some(self.sliding_window))?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

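/// The GGUF-quantized Phi-3 model with an optional X-LoRA classifier. Holds the token
/// embeddings, the transformer layers, the final norm, the (LoRA-capable) output head,
/// an optional device mapper, and the KV cache.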
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

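/// Precomputes the RoPE cos/sin tables, each of shape `(context_window, head_dim / 2)`,
/// for the given frequency base.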
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

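// Builds the model from a GGUF file: reads the Phi-3 metadata, loads each layer's
// quantized tensors, wraps the projections in `QLoraLinear` with any configured
// adapters, and (when no X-LoRA config is given) merges the LoRA weights up front.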
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
        if xlora_config.is_none() {
            // No X-LoRA config, so this is a plain LoRA model: merge the adapter
            // weights into the quantized base weights now.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            // An adapter on `lm_head` is not supported with X-LoRA; this is also why
            // `forward` can pass dummy scaling values (None, 1.0, None) to the output layer.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
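    /// Runs the transformer stack (embeddings, layers, final norm) without the LM head.
    /// When `is_full_pass` is set, the X-LoRA cache is used (and reset if `no_kv_cache`);
    /// otherwise the regular KV cache is used. Returns normalized hidden states on the
    /// model's main device.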
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

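    /// Full forward pass. With an X-LoRA classifier, first computes per-token adapter
    /// scalings (the scaling pass), then runs the scaled pass and applies the LM head;
    /// without one, this behaves as a plain (merged) LoRA model.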
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

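// ScalingsMaker lets the shared X-LoRA machinery run the scaling pass through this
// model: it exposes the cache and classifier and forwards through `inner_forward`.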
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32 // for dummy scalings
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}