mistralrs_core/xlora_models/mixtral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
/// Mixtral Model
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, RmsNorm},
    models::mixtral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

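// Grouped-query attention block. All projections are LoRA-aware (`LinearLayerLike`),
// so each matmul can apply per-adapter scalings during X-LoRA passes. `num_kv_heads`
// may be smaller than `num_heads` (GQA), and `sliding_window` mirrors the Mixtral
// config for windowed attention.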
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

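        // Shape flow: the projections produce (b, q_len, n_heads * head_dim). For prefill
        // (q_len > 1) we reshape to (b, q_len, n_heads, head_dim) and transpose to
        // (b, n_heads, q_len, head_dim); for single-token decode the transpose would be a
        // no-op, so we reshape directly.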
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

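// A single Mixtral expert: a SwiGLU-style feed-forward network computing
// w2(act_fn(w1(x)) * w3(x)), with every projection wrapped for LoRA/X-LoRA.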
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn LinearLayerLike + Send + Sync>,
    w2: Arc<dyn LinearLayerLike + Send + Sync>,
    w3: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w2 = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w3 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w3"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .w1
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.w3.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.w2.lora_forward(
            &(lhs * rhs)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

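// Mixtral's sparse MoE layer: a gate projects each token onto `num_local_experts`
// logits, the top `num_experts_per_tok` experts are selected per token, and their
// outputs are combined with renormalized routing weights.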
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn LinearLayerLike + Send + Sync>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let gate = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(
                cfg,
                vb.pp(idx),
                lora_config,
                count,
                ord,
                mapper,
                layer_idx,
                loading_isq,
                preload_adapters,
            )?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = self.gate.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // In order to extract topk, we extract the data from the tensor and manipulate it
        // directly. Maybe we will want to use some custom ops instead at some point.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        // top_x contains the row indexes to evaluate for each expert.
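        // For example, with 8 experts, `num_experts_per_tok = 2`, and softmax weights
        // [0.05, 0.40, 0.25, 0.10, 0.05, 0.05, 0.05, 0.05], experts 1 and 2 are selected
        // and their weights are renormalized to 0.40 / 0.65 ≈ 0.615 and 0.25 / 0.65 ≈ 0.385.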
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

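        // Accumulate expert outputs: for each expert, gather the rows routed to it,
        // run the expert MLP, scale by the renormalized routing weights, and
        // scatter-add the result back into the output at the original row positions.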
        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Index the correct hidden states and compute the expert hidden state for
            // the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
            let current_hidden_states = expert_layer.forward(
                &current_state,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

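// One Mixtral decoder layer: pre-norm self-attention followed by a pre-norm sparse
// MoE block, each wrapped in a residual connection.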
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            vb.pp("block_sparse_moe"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .block_sparse_moe
            .forward(
                &xs.apply(&self.post_attention_layernorm)?,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

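// The full Mixtral X-LoRA model: token embeddings, the stack of decoder layers, a final
// RmsNorm, and an lm_head. When `xlora_classifier` is present, a separate scaling pass
// computes per-token adapter scalings; when it is absent (plain LoRA with no preloaded
// adapters), the adapter weights are merged into the base weights at load time.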
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    dtype: DType,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .merge_weights()?;
                for expert in layer.block_sparse_moe.experts.iter_mut() {
                    Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
                }
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // This is why we can pass dummy values (..., None, 1.0, None)
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

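    // Runs the embeddings, all decoder layers, and the final norm. When `is_full_pass`
    // is true the dedicated X-LoRA cache (`xlora_lock`) is used, otherwise the regular
    // cache; if `no_kv_cache` is also set on a full pass, the X-LoRA cache is cleared
    // first so the pass starts from an empty cache.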
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

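    // With an X-LoRA classifier, generation is a two-pass process: `get_scalings` (from
    // `ScalingsMaker`) first computes per-token adapter scalings, and the main pass then
    // applies them. Without a classifier, this is a plain (merged) LoRA forward pass.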
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

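// In-situ quantization (ISQ) support: collects every quantizable linear in the model
// (lm_head plus the attention, gate, and expert projections of each decoder layer),
// tagged with its layer index (`None` for lm_head), alongside the device mapper.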
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

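// `ScalingsMaker` supplies the hooks used to compute the X-LoRA scalings (see
// `get_scalings` in `forward` above): the classifier, the cache, and a scaling-pass
// forward that reuses `inner_forward`.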
impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}