mistralrs_core/xlora_models/quantized_phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::AdapterSwapper;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

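/// Phi-3 gated MLP with LoRA-wrapped projections: `ffn_up` produces a fused gate/up
/// output of width `2 * i_size`, which is split, gated with SiLU, and fed to `ffn_down`.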
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

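/// Dequantizes a GGUF norm weight and wraps it in an `RmsNorm`.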
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

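/// One transformer block: LoRA-wrapped fused QKV and output projections, the
/// pre-attention and pre-FFN RMS norms, the gated MLP, and per-layer RoPE tables.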
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
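    /// Applies rotary position embeddings to each batch element, slicing the
    /// precomputed cos/sin tables at that sequence's position offset.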
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

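    /// Self-attention for one block: fused QKV projection (with optional LoRA
    /// scalings), RoPE, sliding-window KV cache update, SDPA, and the output projection.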
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            mask,
            Some(self.sliding_window),
            true,
        )?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

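/// GGUF-quantized Phi-3 model with LoRA/X-LoRA adapters wrapped around the supported layers.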
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

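/// Precomputes the RoPE cos/sin tables for `context_window` positions at base frequency `freq_base`.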
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

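// Construction from GGUF: each supported projection is wrapped in a `QLoraLinear`, and
// the X-LoRA classifier is built when an `XLoraConfig` is provided.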
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
        if xlora_config.is_none() {
            // No X-LoRA config: this is a plain LoRA model, so merge the adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            // An adapter on `lm_head` is unsupported with X-LoRA; rejecting it here is also
            // why `forward` can pass dummy values (..., None, 1.0, None) to the output layer.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
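    /// Activates the named LoRA adapters on every wrapped projection, returning the
    /// total count reported by the layers. Unsupported for X-LoRA models.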
    pub fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += layer.attn_qkv.activate(&adapter_names)?;
            sum += layer.attn_output.activate(&adapter_names)?;
            sum += layer.mlp.ffn_down.activate(&adapter_names)?;
            sum += layer.mlp.ffn_up.activate(&adapter_names)?;
        }
        Ok(sum)
    }

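    /// Runs the embedding, transformer blocks, and final norm (no LM head), optionally
    /// applying X-LoRA scalings; used for both the scaling pass and the full pass.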
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

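    /// Full forward pass: when an X-LoRA classifier is present, first computes the
    /// scalings, then runs the model and LM head and extracts logits for `context_lens`.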
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

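// ScalingsMaker lets the X-LoRA classifier run the dummy scaling pass through this model.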
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32 // for dummy scalings
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}