mistralrs_core/xlora_models/quantized_phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

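// Module names that may carry LoRA adapters in this model; the adapter `Ordering` is
// checked against this list by `verify_sanity_adapters`.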
const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

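// Phi-3 MLP with a fused gate/up projection: `ffn_up` produces 2 * i_size features that
// are split into gate and up halves, combined as `up * silu(gate)`, then fed to `ffn_down`.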
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

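// Dequantizes an RMS-norm weight from the GGUF and wraps it in `RmsNorm`.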
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

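// Weights for one transformer block: fused QKV and output projections (both LoRA-capable),
// the two norms, the MLP, and per-layer RoPE tables plus SDPA parameters.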
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
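    // Applies rotary position embeddings per sequence, indexing the precomputed cos/sin
    // tables by each sequence's own offset.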
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

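    // Attention for one block: fused QKV projection, head split, RoPE, sliding-window KV
    // cache update, SDPA, then the LoRA-capable output projection.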
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) =
            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, Some(self.sliding_window))?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

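// The full quantized Phi-3 model. `xlora_classifier` is `Some` only for X-LoRA; for plain
// LoRA the adapters are merged into the layer weights during loading.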
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

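// Precomputes the RoPE cos/sin tables, each of shape (context_window, head_dim / 2).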
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

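// Builds the model from GGUF content, wrapping each supported projection in a `QLoraLinear`
// with its adapters. Without an X-LoRA config the adapters are merged immediately; with one,
// an `XLoraClassifier` is constructed from it.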
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
        if xlora_config.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            // This is why we can pass dummy values (..., None, 1.0, None)?
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
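    // Runs the embedding, transformer blocks, and final norm (no LM head). `is_full_pass`
    // selects the X-LoRA cache (reset first when `no_kv_cache` is set); `scalings` and
    // `is_scaling_pass` are forwarded to every LoRA layer.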
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

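    // With an X-LoRA classifier, first obtain per-token adapter scalings via `get_scalings`
    // (the scaling pass), then rerun the model with those scalings and project through
    // `lm_head`. Without a classifier this reduces to a single ordinary forward pass.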
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

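// Hooks this model into the shared X-LoRA scaling machinery: `get_scalings`, as used in
// `forward` above, drives `inner_forward` through this impl to compute adapter scalings.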
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32 // for dummy scalings
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}