mistralrs_core/xlora_models/
mixtral.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use crate::{
4    amoe::AnyMoeBaseModelMixin,
5    attention::SdpaParams,
6    layers::{self, Activation, RotaryEmbedding, Sdpa},
7    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
8    paged_attention::ModelConfigMetadata,
9    pipeline::{
10        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
11        EitherCache, IsqModel, NormalLoadingMetadata,
12    },
13    utils::progress::NiceProgressBar,
14};
15/// Mixtral Model
16/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
17/// https://mistral.ai/news/mixtral-of-experts/
18use candle_core::{DType, Device, Module, Result, Tensor};
19use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
20use std::{collections::HashMap, sync::Arc};
21use tqdm::Iter;
22use tracing::info;
23
24use crate::{
25    device_map::DeviceMapper,
26    layers::{CausalMasker, RmsNorm},
27    models::mixtral::Config,
28    pipeline::{extract_logits, Cache, NormalModel},
29};
30
31use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};
32
33struct Attention {
34    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
35    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
36    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
37    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
38    num_heads: usize,
39    num_kv_heads: usize,
40    head_dim: usize,
41    rotary_emb: Arc<RotaryEmbedding>,
42    sliding_window: Option<usize>,
43    sdpa_params: SdpaParams,
44}
45
46impl Attention {
47    #[allow(clippy::too_many_arguments)]
48    fn new(
49        rotary_emb: Arc<RotaryEmbedding>,
50        cfg: &Config,
51        vb: ShardedVarBuilder,
52        lora_config: &[((String, String), LoraConfig)],
53        count: &mut usize,
54        ord: &Ordering,
55        mapper: &dyn DeviceMapper,
56        layer_idx: usize,
57        loading_isq: bool,
58        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
59    ) -> Result<Self> {
60        let hidden_sz = cfg.hidden_size;
61        let num_heads = cfg.num_attention_heads;
62        let num_kv_heads = cfg.num_key_value_heads;
63        let head_dim = hidden_sz / num_heads;
64        let q_proj = linear_no_bias(
65            hidden_sz,
66            num_heads * head_dim,
67            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
68            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
69            lora_config,
70            count,
71            ord,
72            preload_adapters,
73        )?;
74        let k_proj = linear_no_bias(
75            hidden_sz,
76            num_kv_heads * head_dim,
77            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
78            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
79            lora_config,
80            count,
81            ord,
82            preload_adapters,
83        )?;
84        let v_proj = linear_no_bias(
85            hidden_sz,
86            num_kv_heads * head_dim,
87            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
88            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
89            lora_config,
90            count,
91            ord,
92            preload_adapters,
93        )?;
94        let o_proj = linear_no_bias(
95            num_heads * head_dim,
96            hidden_sz,
97            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
98            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
99            lora_config,
100            count,
101            ord,
102            preload_adapters,
103        )?;
104        Ok(Self {
105            q_proj,
106            k_proj,
107            v_proj,
108            o_proj,
109            num_heads,
110            num_kv_heads,
111            head_dim,
112            rotary_emb,
113            sliding_window: cfg.sliding_window,
114            sdpa_params: SdpaParams {
115                n_kv_groups: num_heads / num_kv_heads,
116                use_flash_attn: cfg.use_flash_attn,
117                softcap: None,
118                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
119                sliding_window: cfg.sliding_window,
120            },
121        })
122    }
123
124    #[allow(clippy::too_many_arguments)]
125    fn forward(
126        &self,
127        xs: &Tensor,
128        attention_mask: Option<&Tensor>,
129        seqlen_offsets: &[usize],
130        kv_cache: &mut Option<(Tensor, Tensor)>,
131        scalings: Option<Tensor>,
132        global_scaling_weight: f64,
133        is_scaling_pass: Option<f64>,
134        flash_params: &FlashParams,
135    ) -> Result<Tensor> {
136        let (b_sz, q_len, _) = xs.dims3()?;
137
138        let original_dtype = xs.dtype();
139        let mut xs = xs.clone();
140        if let Some(t) = self.q_proj.quantized_act_type() {
141            xs = xs.to_dtype(t)?;
142        }
143        let mut q = self.q_proj.lora_forward(
144            &xs,
145            scalings.clone(),
146            global_scaling_weight,
147            is_scaling_pass,
148        )?;
149        let mut k = self.k_proj.lora_forward(
150            &xs,
151            scalings.clone(),
152            global_scaling_weight,
153            is_scaling_pass,
154        )?;
155        let mut v = self.v_proj.lora_forward(
156            &xs,
157            scalings.clone(),
158            global_scaling_weight,
159            is_scaling_pass,
160        )?;
161        if self.q_proj.quantized_act_type().is_some() {
162            q = q.to_dtype(original_dtype)?;
163            k = k.to_dtype(original_dtype)?;
164            v = v.to_dtype(original_dtype)?;
165        }
166
167        let (q, k, v) = if q_len != 1 {
168            let q = q
169                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
170                .transpose(1, 2)?;
171            let k = k
172                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
173                .transpose(1, 2)?;
174            let v = v
175                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
176                .transpose(1, 2)?;
177            (q, k, v)
178        } else {
179            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
180            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
181            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
182            (q, k, v)
183        };
184
185        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
186
187        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
188            kv_cache,
189            k,
190            v,
191            attention_mask,
192            self.sliding_window,
193            false,
194        )?;
195
196        let mut attn_output = Sdpa.run_attention(
197            &q,
198            &k,
199            &v,
200            attn_mask.as_ref(),
201            Some(flash_params),
202            &self.sdpa_params,
203        )?;
204
205        if let Some(t) = self.q_proj.quantized_act_type() {
206            attn_output = attn_output.to_dtype(t)?;
207        }
208        let mut res = self.o_proj.lora_forward(
209            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
210            scalings.clone(),
211            global_scaling_weight,
212            is_scaling_pass,
213        )?;
214        if self.q_proj.quantized_act_type().is_some() {
215            res = res.to_dtype(original_dtype)?;
216        }
217        Ok(res)
218    }
219}
220
221#[derive(Clone)]
222struct BlockSparseTop2MLP {
223    w1: Arc<dyn LinearLayerLike + Send + Sync>,
224    w2: Arc<dyn LinearLayerLike + Send + Sync>,
225    w3: Arc<dyn LinearLayerLike + Send + Sync>,
226    act_fn: Activation,
227}
228
229impl BlockSparseTop2MLP {
230    #[allow(clippy::too_many_arguments)]
231    fn new(
232        cfg: &Config,
233        vb: ShardedVarBuilder,
234        lora_config: &[((String, String), LoraConfig)],
235        count: &mut usize,
236        ord: &Ordering,
237        mapper: &dyn DeviceMapper,
238        layer_idx: usize,
239        loading_isq: bool,
240        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
241    ) -> Result<Self> {
242        let hidden_sz = cfg.hidden_size;
243        let intermediate_sz = cfg.intermediate_size;
244        let w1 = linear_no_bias(
245            hidden_sz,
246            intermediate_sz,
247            mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
248            mapper.set_device(layer_idx, vb.pp("w1"), false),
249            lora_config,
250            count,
251            ord,
252            preload_adapters,
253        )?;
254        let w2 = linear_no_bias(
255            intermediate_sz,
256            hidden_sz,
257            mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
258            mapper.set_device(layer_idx, vb.pp("w2"), false),
259            lora_config,
260            count,
261            ord,
262            preload_adapters,
263        )?;
264        let w3 = linear_no_bias(
265            hidden_sz,
266            intermediate_sz,
267            mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
268            mapper.set_device(layer_idx, vb.pp("w3"), false),
269            lora_config,
270            count,
271            ord,
272            preload_adapters,
273        )?;
274        Ok(Self {
275            w1,
276            w2,
277            w3,
278            act_fn: cfg.hidden_act,
279        })
280    }
281
282    fn forward(
283        &self,
284        xs: &Tensor,
285        scalings: Option<Tensor>,
286        global_scaling_weight: f64,
287        is_scaling_pass: Option<f64>,
288    ) -> Result<Tensor> {
289        let original_dtype = xs.dtype();
290        let mut xs = xs.clone();
291        if let Some(t) = self.w1.quantized_act_type() {
292            xs = xs.to_dtype(t)?;
293        }
294        let lhs = self
295            .w1
296            .lora_forward(
297                &xs,
298                scalings.clone(),
299                global_scaling_weight,
300                is_scaling_pass,
301            )?
302            .apply(&self.act_fn)?;
303        let rhs = self.w3.lora_forward(
304            &xs,
305            scalings.clone(),
306            global_scaling_weight,
307            is_scaling_pass,
308        )?;
309        let mut res = self.w2.lora_forward(
310            &(lhs * rhs)?,
311            scalings.clone(),
312            global_scaling_weight,
313            is_scaling_pass,
314        )?;
315        if self.w1.quantized_act_type().is_some() {
316            res = res.to_dtype(original_dtype)?;
317        }
318        Ok(res)
319    }
320}
321
322#[derive(Clone)]
323struct SparseMoeBlock {
324    gate: Arc<dyn LinearLayerLike + Send + Sync>,
325    experts: Vec<BlockSparseTop2MLP>,
326    num_experts_per_tok: usize,
327}
328
329impl SparseMoeBlock {
330    #[allow(clippy::too_many_arguments)]
331    fn new(
332        cfg: &Config,
333        vb: ShardedVarBuilder,
334        lora_config: &[((String, String), LoraConfig)],
335        count: &mut usize,
336        ord: &Ordering,
337        mapper: &dyn DeviceMapper,
338        layer_idx: usize,
339        loading_isq: bool,
340        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
341    ) -> Result<Self> {
342        let gate = linear_no_bias(
343            cfg.hidden_size,
344            cfg.num_local_experts,
345            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
346            mapper.set_device(layer_idx, vb.pp("gate"), false),
347            lora_config,
348            count,
349            ord,
350            preload_adapters,
351        )?;
352        let mut experts = Vec::with_capacity(cfg.num_local_experts);
353        let vb = vb.pp("experts");
354        for idx in 0..cfg.num_local_experts {
355            let expert = BlockSparseTop2MLP::new(
356                cfg,
357                vb.pp(idx),
358                lora_config,
359                count,
360                ord,
361                mapper,
362                layer_idx,
363                loading_isq,
364                preload_adapters,
365            )?;
366            experts.push(expert)
367        }
368        Ok(SparseMoeBlock {
369            gate,
370            experts,
371            num_experts_per_tok: cfg.num_experts_per_tok,
372        })
373    }
374
375    fn forward(
376        &self,
377        xs: &Tensor,
378        scalings: Option<Tensor>,
379        global_scaling_weight: f64,
380        is_scaling_pass: Option<f64>,
381    ) -> Result<Tensor> {
382        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
383        let xs = xs.reshape(((), hidden_dim))?;
384
385        let original_dtype = xs.dtype();
386        let mut xs = xs.clone();
387        if let Some(t) = self.gate.quantized_act_type() {
388            xs = xs.to_dtype(t)?;
389        }
390        let mut router_logits = self.gate.lora_forward(
391            &xs,
392            scalings.clone(),
393            global_scaling_weight,
394            is_scaling_pass,
395        )?;
396        if self.gate.quantized_act_type().is_some() {
397            router_logits = router_logits.to_dtype(original_dtype)?;
398        }
399
400        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
401
402        // In order to extract topk, we extract the data from the tensor and manipulate it
403        // directly. Maybe we will want to use some custom ops instead at some point.
404        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
405
406        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
407        // top_x contains the row indexes to evaluate for each expert.
408        let mut top_x = vec![vec![]; self.experts.len()];
409        let mut selected_rws = vec![vec![]; self.experts.len()];
410        for (row_idx, rw) in routing_weights.iter().enumerate() {
411            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
412            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
413            let mut sum_routing_weights = 0f32;
414            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
415                let expert_idx = expert_idx as usize;
416                let routing_weight = rw[expert_idx];
417                sum_routing_weights += routing_weight;
418                top_x[expert_idx].push(row_idx as u32);
419            }
420            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
421                let expert_idx = expert_idx as usize;
422                let routing_weight = rw[expert_idx];
423                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
424            }
425        }
426
427        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
428        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
429
430        let mut ys = xs.zeros_like()?;
431        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
432            let top_x = &top_x[expert_idx];
433            if top_x.is_empty() {
434                continue;
435            }
436            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
437            let selected_rws =
438                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
439            // Index the correct hidden states and compute the expert hidden state for
440            // the current expert. We need to make sure to multiply the output hidden
441            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
442            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
443            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
444            let current_hidden_states = expert_layer.forward(
445                &current_state,
446                scalings.clone(),
447                global_scaling_weight,
448                is_scaling_pass,
449            )?;
450            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
451            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
452        }
453
454        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
455        Ok(ys)
456    }
457}
458
459struct DecoderLayer {
460    self_attn: Attention,
461    block_sparse_moe: SparseMoeBlock,
462    input_layernorm: RmsNorm,
463    post_attention_layernorm: RmsNorm,
464}
465
466impl DecoderLayer {
467    #[allow(clippy::too_many_arguments)]
468    fn new(
469        rotary_emb: Arc<RotaryEmbedding>,
470        cfg: &Config,
471        vb: ShardedVarBuilder,
472        lora_config: &[((String, String), LoraConfig)],
473        count: &mut usize,
474        ord: &Ordering,
475        mapper: &dyn DeviceMapper,
476        layer_idx: usize,
477        loading_isq: bool,
478        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
479    ) -> Result<Self> {
480        let self_attn = Attention::new(
481            rotary_emb,
482            cfg,
483            vb.pp("self_attn"),
484            lora_config,
485            count,
486            ord,
487            mapper,
488            layer_idx,
489            loading_isq,
490            preload_adapters,
491        )?;
492        let block_sparse_moe = SparseMoeBlock::new(
493            cfg,
494            vb.pp("block_sparse_moe"),
495            lora_config,
496            count,
497            ord,
498            mapper,
499            layer_idx,
500            loading_isq,
501            preload_adapters,
502        )?;
503        let input_layernorm = RmsNorm::new(
504            cfg.hidden_size,
505            cfg.rms_norm_eps,
506            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
507        )?;
508        let post_attention_layernorm = RmsNorm::new(
509            cfg.hidden_size,
510            cfg.rms_norm_eps,
511            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
512        )?;
513        Ok(Self {
514            self_attn,
515            block_sparse_moe,
516            input_layernorm,
517            post_attention_layernorm,
518        })
519    }
520
521    #[allow(clippy::too_many_arguments)]
522    fn forward(
523        &self,
524        xs: &Tensor,
525        attention_mask: Option<&Tensor>,
526        seqlen_offsets: &[usize],
527        kv_cache: &mut Option<(Tensor, Tensor)>,
528        scalings: Option<Tensor>,
529        global_scaling_weight: f64,
530        is_scaling_pass: Option<f64>,
531        flash_params: &FlashParams,
532    ) -> Result<Tensor> {
533        let residual = xs;
534        let xs = self.input_layernorm.forward(xs)?;
535        let xs = self.self_attn.forward(
536            &xs,
537            attention_mask,
538            seqlen_offsets,
539            kv_cache,
540            scalings.clone(),
541            global_scaling_weight,
542            is_scaling_pass,
543            flash_params,
544        )?;
545        let xs = (xs + residual)?;
546        let residual = &xs;
547        let xs = self
548            .block_sparse_moe
549            .forward(
550                &xs.apply(&self.post_attention_layernorm)?,
551                scalings.clone(),
552                global_scaling_weight,
553                is_scaling_pass,
554            )?
555            .to_dtype(residual.dtype())?;
556        residual + xs
557    }
558}
559
560pub struct XLoraModel {
561    embed_tokens: candle_nn::Embedding,
562    layers: Vec<DecoderLayer>,
563    norm: RmsNorm,
564    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
565    sliding_window: Option<usize>,
566    device: Device,
567    cache: EitherCache,
568    dtype: DType,
569    max_seq_len: usize,
570    xlora_classifier: Option<XLoraClassifier>,
571    mapper: Box<dyn DeviceMapper + Send + Sync>,
572    cfg: ModelConfigMetadata,
573}
574
575impl XLoraModel {
576    #[allow(clippy::too_many_arguments)]
577    pub fn new(
578        cfg: &Config,
579        vb: ShardedVarBuilder,
580        lora_config: &[((String, String), LoraConfig)],
581        xlora_config: Option<XLoraConfig>,
582        xlora_ordering: Ordering,
583        is_gptx: bool,
584        normal_loading_metadata: NormalLoadingMetadata,
585        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
586    ) -> Result<Self> {
587        if let Some(ref quant_cfg) = &cfg.quantization_config {
588            tracing::info!(
589                "Using {} quantization: {}.",
590                quant_cfg.quant_method.to_string(),
591                quant_cfg.get_bits_name(&vb)
592            );
593        }
594        let mapper = normal_loading_metadata.mapper;
595        let vb_m = vb.pp("model");
596
597        let embed_tokens = layers::embedding(
598            cfg.vocab_size,
599            cfg.hidden_size,
600            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
601        )?;
602        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
603        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
604        let vb_l = vb_m.pp("layers");
605        let mut ropes = HashMap::new();
606        for layer_idx in 0..cfg.num_hidden_layers {
607            let device = mapper
608                .device_for(layer_idx, false)
609                .unwrap_or(&normal_loading_metadata.real_device);
610            ropes.insert(
611                device.location(),
612                Arc::new(RotaryEmbedding::new(
613                    cfg.rope_theta as f32,
614                    head_dim,
615                    cfg.max_position_embeddings,
616                    device,
617                    is_gptx,
618                    vb_m.dtype(),
619                )?),
620            );
621        }
622
623        let mut count = 0;
624        for layer_idx in NiceProgressBar::<_, 'b'>(
625            0..cfg.num_hidden_layers,
626            "Loading repeating layers",
627            &normal_loading_metadata.multi_progress,
628        ) {
629            let device = mapper
630                .device_for(layer_idx, false)
631                .unwrap_or(&normal_loading_metadata.real_device);
632            let rotary_emb = ropes
633                .get(&device.location())
634                .expect("No RoPE for device location!")
635                .clone();
636            let layer = DecoderLayer::new(
637                rotary_emb.clone(),
638                cfg,
639                vb_l.pp(layer_idx),
640                lora_config,
641                &mut count,
642                &xlora_ordering,
643                &*mapper,
644                layer_idx,
645                normal_loading_metadata.loading_isq,
646                preload_adapters,
647            )?;
648            layers.push(layer)
649        }
650        if xlora_config.is_none() && preload_adapters.is_none() {
651            // We are now a LoRA model so we must merge the weights
652            info!("Merging LoRA adapters.");
653            for layer in layers.iter_mut().tqdm() {
654                Arc::get_mut(&mut layer.self_attn.k_proj)
655                    .unwrap()
656                    .merge_weights()?;
657                Arc::get_mut(&mut layer.self_attn.o_proj)
658                    .unwrap()
659                    .merge_weights()?;
660                Arc::get_mut(&mut layer.self_attn.q_proj)
661                    .unwrap()
662                    .merge_weights()?;
663                Arc::get_mut(&mut layer.self_attn.v_proj)
664                    .unwrap()
665                    .merge_weights()?;
666
667                Arc::get_mut(&mut layer.block_sparse_moe.gate)
668                    .unwrap()
669                    .merge_weights()?;
670                for expert in layer.block_sparse_moe.experts.iter_mut() {
671                    Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
672                    Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
673                    Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
674                }
675            }
676        }
677        let norm = RmsNorm::new(
678            cfg.hidden_size,
679            cfg.rms_norm_eps,
680            mapper.set_nm_device(vb_m.pp("norm"), false),
681        )?;
682        let lm_head = linear_no_bias(
683            cfg.hidden_size,
684            cfg.vocab_size,
685            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
686            mapper.set_nm_device(vb.pp("lm_head"), false),
687            lora_config,
688            &mut count,
689            &xlora_ordering,
690            preload_adapters,
691        )?;
692        if xlora_config.is_some() && lm_head.is_lora() {
693            // This is why we can pass dummy values (..., None, 1.0, None)?
694            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
695        }
696        Ok(Self {
697            embed_tokens,
698            layers,
699            norm,
700            lm_head,
701            sliding_window: cfg.sliding_window,
702            device: normal_loading_metadata.real_device,
703            dtype: vb.dtype(),
704            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
705            max_seq_len: cfg.max_position_embeddings,
706            xlora_classifier: xlora_config.map(|xlora_config| {
707                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
708            }),
709            mapper,
710            cfg: ModelConfigMetadata {
711                max_seq_len: cfg.max_position_embeddings,
712                num_layers: cfg.num_hidden_layers,
713                hidden_size: cfg.hidden_size,
714                num_kv_heads: cfg.num_key_value_heads,
715                num_attn_heads: cfg.num_attention_heads,
716                sliding_window: cfg.sliding_window,
717                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
718                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
719            },
720        })
721    }
722
723    #[allow(clippy::too_many_arguments)]
724    fn inner_forward(
725        &self,
726        input_ids: &Tensor,
727        seqlen_offsets: &[usize],
728        scalings: Option<Tensor>,
729        is_full_pass: bool,
730        no_kv_cache: bool,
731        is_scaling_pass: Option<f64>,
732        flash_params: &FlashParams,
733    ) -> Result<Tensor> {
734        let mut cache = if is_full_pass {
735            if no_kv_cache {
736                let mut new_cache = Vec::new();
737                for _ in 0..self.cache.full().xlora_lock().len() {
738                    new_cache.push(None);
739                }
740
741                self.cache.full().xlora_lock().clone_from(&new_cache);
742            }
743            self.cache.full().xlora_lock()
744        } else {
745            self.cache.full().lock()
746        };
747        let mut xs = self.embed_tokens.forward(input_ids)?;
748        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
749            input_ids,
750            &*cache,
751            self.sliding_window,
752            xs.dtype(),
753            self.cfg.num_attn_heads,
754        )?;
755        for (i, layer) in self.layers.iter().enumerate() {
756            xs = self.mapper.map(xs, i)?;
757            xs = layer.forward(
758                &xs,
759                attention_mask
760                    .as_ref()
761                    .map(|m| m.to_device(xs.device()).unwrap())
762                    .as_ref(),
763                seqlen_offsets,
764                &mut cache[i],
765                scalings.clone(),
766                self.xlora_classifier
767                    .as_ref()
768                    .map(|classifier| classifier.get_global_scaling_weight())
769                    .unwrap_or(1.0),
770                is_scaling_pass,
771                flash_params,
772            )?
773        }
774        let xs = xs.to_device(&self.device)?;
775        xs.apply(&self.norm)
776    }
777
778    #[allow(clippy::too_many_arguments)]
779    pub fn forward(
780        &self,
781        input_ids: &Tensor,
782        input_ids_full: &Tensor,
783        seqlen_offsets: &[usize],
784        seqlen_offsets_full: &[usize],
785        no_kv_cache: bool,
786        non_granular_state: &Option<NonGranularState>,
787        context_lens: Vec<(usize, usize)>,
788        flash_params: &FlashParams,
789        flash_params_full: &FlashParams,
790    ) -> Result<Tensor> {
791        if self.xlora_classifier.is_some() {
792            let scalings = self.get_scalings(
793                input_ids,
794                input_ids_full,
795                seqlen_offsets,
796                seqlen_offsets_full,
797                no_kv_cache,
798                non_granular_state,
799                &vec![usize::MAX; context_lens.len()],
800                flash_params,
801                flash_params_full,
802            )?;
803
804            if no_kv_cache {
805                let mut res = self
806                    .inner_forward(
807                        input_ids_full,
808                        seqlen_offsets_full,
809                        Some(scalings),
810                        true,
811                        no_kv_cache,
812                        None,
813                        flash_params_full,
814                    )?
815                    .contiguous()?;
816                if let Some(t) = self.lm_head.quantized_act_type() {
817                    res = res.to_dtype(t)?;
818                }
819                extract_logits(
820                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
821                    context_lens,
822                )
823            } else {
824                // is_full_pass=true is ok because no_kv_cache=false
825                let mut res = self
826                    .inner_forward(
827                        input_ids,
828                        seqlen_offsets,
829                        Some(scalings),
830                        true,
831                        no_kv_cache,
832                        None,
833                        flash_params,
834                    )?
835                    .contiguous()?;
836                if let Some(t) = self.lm_head.quantized_act_type() {
837                    res = res.to_dtype(t)?;
838                }
839                extract_logits(
840                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
841                    context_lens,
842                )
843            }
844        } else {
845            let mut res = self
846                .inner_forward(
847                    input_ids,
848                    seqlen_offsets,
849                    None,
850                    false,
851                    no_kv_cache,
852                    None,
853                    flash_params,
854                )?
855                .contiguous()?;
856            if let Some(t) = self.lm_head.quantized_act_type() {
857                res = res.to_dtype(t)?;
858            }
859            extract_logits(
860                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
861                context_lens,
862            )
863        }
864    }
865}
866
867impl IsqModel for XLoraModel {
868    fn get_layers(
869        &mut self,
870    ) -> (
871        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
872        &dyn DeviceMapper,
873    ) {
874        let mut tensors = Vec::new();
875        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
876        for (i, layer) in self.layers.iter_mut().enumerate() {
877            tensors.push((
878                Arc::get_mut(&mut layer.self_attn.q_proj)
879                    .unwrap()
880                    .quant_inner(),
881                Some(i),
882            ));
883            tensors.push((
884                Arc::get_mut(&mut layer.self_attn.k_proj)
885                    .unwrap()
886                    .quant_inner(),
887                Some(i),
888            ));
889            tensors.push((
890                Arc::get_mut(&mut layer.self_attn.v_proj)
891                    .unwrap()
892                    .quant_inner(),
893                Some(i),
894            ));
895            tensors.push((
896                Arc::get_mut(&mut layer.self_attn.o_proj)
897                    .unwrap()
898                    .quant_inner(),
899                Some(i),
900            ));
901            tensors.push((
902                Arc::get_mut(&mut layer.block_sparse_moe.gate)
903                    .unwrap()
904                    .quant_inner(),
905                Some(i),
906            ));
907            for expert in &mut layer.block_sparse_moe.experts {
908                tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
909                tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
910                tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
911            }
912        }
913        (tensors, &*self.mapper)
914    }
915
916    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
917        panic!("Cannot generate UQFF for an adapter model.")
918    }
919}
920
921impl NormalModel for XLoraModel {
922    fn forward(
923        &self,
924        _input_ids: &Tensor,
925        _seqlen_offsets: &[usize],
926        _context_lens: Vec<(usize, usize)>,
927        _position_ids: Vec<usize>,
928        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
929        _flash_params: &FlashParams,
930    ) -> Result<Tensor> {
931        unreachable!()
932    }
933    fn xlora_forward(
934        &self,
935        input_ids: &Tensor,
936        input_ids_full: &Tensor,
937        seqlen_offsets: &[usize],
938        seqlen_offsets_full: &[usize],
939        no_kv_cache: bool,
940        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
941        context_lens: Vec<(usize, usize)>,
942        _position_ids: Vec<usize>,
943        flash_params: &FlashParams,
944        flash_params_full: &FlashParams,
945    ) -> Result<Tensor> {
946        self.forward(
947            input_ids,
948            input_ids_full,
949            seqlen_offsets,
950            seqlen_offsets_full,
951            no_kv_cache,
952            non_granular_state,
953            context_lens,
954            flash_params,
955            flash_params_full,
956        )
957    }
958    fn cache(&self) -> &EitherCache {
959        &self.cache
960    }
961    fn cache_mut(&mut self) -> &mut EitherCache {
962        &mut self.cache
963    }
964    fn device(&self) -> &Device {
965        &self.device
966    }
967    fn is_xlora(&self) -> bool {
968        true
969    }
970    fn max_seq_len(&self) -> usize {
971        self.max_seq_len
972    }
973    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
974        if self.xlora_classifier.is_some() {
975            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
976        }
977        let mut sum = 0;
978        for layer in self.layers.iter_mut() {
979            sum += Arc::get_mut(&mut layer.self_attn.k_proj)
980                .unwrap()
981                .activate(&adapter_names)?;
982            sum += Arc::get_mut(&mut layer.self_attn.o_proj)
983                .unwrap()
984                .activate(&adapter_names)?;
985            sum += Arc::get_mut(&mut layer.self_attn.q_proj)
986                .unwrap()
987                .activate(&adapter_names)?;
988            sum += Arc::get_mut(&mut layer.self_attn.v_proj)
989                .unwrap()
990                .activate(&adapter_names)?;
991
992            sum += Arc::get_mut(&mut layer.block_sparse_moe.gate)
993                .unwrap()
994                .activate(&adapter_names)?;
995            for expert in &mut layer.block_sparse_moe.experts {
996                sum += Arc::get_mut(&mut expert.w1)
997                    .unwrap()
998                    .activate(&adapter_names)?;
999                sum += Arc::get_mut(&mut expert.w2)
1000                    .unwrap()
1001                    .activate(&adapter_names)?;
1002                sum += Arc::get_mut(&mut expert.w3)
1003                    .unwrap()
1004                    .activate(&adapter_names)?;
1005            }
1006        }
1007        Ok(sum)
1008    }
1009    fn config(&self) -> &ModelConfigMetadata {
1010        &self.cfg
1011    }
1012}
1013
1014impl ScalingsMaker for XLoraModel {
1015    fn dtype(&self) -> DType {
1016        self.dtype
1017    }
1018    fn get_cache(&self) -> &EitherCache {
1019        &self.cache
1020    }
1021    fn get_classifier(&self) -> &XLoraClassifier {
1022        self.xlora_classifier.as_ref().unwrap()
1023    }
1024    fn forward(
1025        &self,
1026        input_ids: &Tensor,
1027        seqlen_offsets: &[usize],
1028        scalings: Tensor,
1029        is_full_pass: bool,
1030        no_kv_cache: bool,
1031        is_scaling_pass: Option<f64>,
1032        _context_lens: &[usize],
1033        flash_params: &FlashParams,
1034    ) -> Result<Tensor> {
1035        self.inner_forward(
1036            input_ids,
1037            seqlen_offsets,
1038            Some(scalings),
1039            is_full_pass,
1040            no_kv_cache,
1041            is_scaling_pass,
1042            flash_params,
1043        )
1044    }
1045}
1046
// Empty impl: `XLoraModel` implements `AnyMoeBaseModelMixin` without overriding any of the
// trait's items.
impl AnyMoeBaseModelMixin for XLoraModel {}