mistralrs_core/xlora_models/quantized_phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

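/// Projection names that may carry (X-)LoRA adapters in this model; the adapter
/// `Ordering` is checked against this list in `verify_sanity_adapters`.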
const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

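/// Gated MLP block: `ffn_up` produces the concatenated `[gate, up]` halves (each of width
/// `i_size`), and the output is `ffn_down(silu(gate) * up)`, with LoRA applied to both projections.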
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

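/// Dequantizes a GGUF weight tensor and wraps it in an `RmsNorm` with the given epsilon.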
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

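/// Weights for a single transformer block: fused QKV and output projections (LoRA-wrapped),
/// pre-attention and pre-MLP RMSNorms, the gated MLP, and per-device rotary cos/sin tables.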
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
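    /// Applies rotary position embeddings per sequence in the batch, slicing the precomputed
    /// cos/sin tables at each sequence's offset.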
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

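    /// Self-attention for one block: project to fused QKV, split heads, apply RoPE, update the
    /// sliding-window KV cache, run SDPA, and project back through `attn_output`.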
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            mask,
            Some(self.sliding_window),
            true,
        )?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

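/// GGUF-quantized Phi-3 model with X-LoRA/LoRA adapters attached to the supported projections,
/// plus the token embedding, final norm, and `lm_head` output.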
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

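/// Precomputes the RoPE cos/sin tables, each of shape `(context_window, head_dim / 2)`.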
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

impl ModelConfig::FromAdapterGGUF for ModelWeights {
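    /// Loads the quantized Phi-3 weights from GGUF content and attaches (X-)LoRA adapters to each
    /// supported projection; when no X-LoRA classifier is configured (plain LoRA), the adapters
    /// are merged into the weights immediately.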
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
        if xlora_config.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            // This is why we can pass dummy values (..., None, 1.0, None)?
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
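    /// Runs the transformer stack and returns the normed hidden states (before `lm_head`).
    /// Full passes use the X-LoRA cache (cleared first when `no_kv_cache` is set); otherwise the
    /// regular KV cache is used.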
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

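    /// Full forward pass. With an X-LoRA classifier, first computes per-token scalings (the
    /// scaling pass via `get_scalings`), then reruns the model with those scalings; otherwise
    /// performs a plain LoRA forward. Logits are extracted according to `context_lens`.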
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

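// The ScalingsMaker implementation exposes this model to the X-LoRA scaling pass; `forward`
// above relies on it (via `get_scalings`) to obtain per-token adapter scalings from the classifier.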
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32 // for dummy scalings
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}