mistralrs_core/xlora_models/mixtral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
/// Mixtral Model
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, RmsNorm},
    models::mixtral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

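// Self-attention with LoRA-wrapped Q/K/V/O projections. Grouped-query
// attention is configured via `num_heads / num_kv_heads` KV groups, with a
// 1/sqrt(head_dim) softmax scale and Mixtral's optional sliding window.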
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

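    // Projects Q/K/V through the LoRA-aware linears (round-tripping through the
    // quantized activation dtype when needed), applies rotary embeddings,
    // updates the sliding-window KV cache, runs SDPA, and applies `o_proj`.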
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

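// A single Mixtral expert: a gated MLP computing w2(act_fn(w1(x)) * w3(x)),
// where `act_fn` comes from `cfg.hidden_act` (SiLU for standard Mixtral).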
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn LinearLayerLike + Send + Sync>,
    w2: Arc<dyn LinearLayerLike + Send + Sync>,
    w3: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w2 = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w3 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w3"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .w1
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.w3.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.w2.lora_forward(
            &(lhs * rhs)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

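// The sparse MoE block: the router (`gate`) scores all `num_local_experts`
// experts per token, only the top `num_experts_per_tok` experts are run on
// that token, and their outputs are combined with renormalized routing weights.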
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn LinearLayerLike + Send + Sync>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let gate = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(
                cfg,
                vb.pp(idx),
                lora_config,
                count,
                ord,
                mapper,
                layer_idx,
                loading_isq,
                preload_adapters,
            )?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = self.gate.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // In order to extract topk, we extract the data from the tensor and manipulate it
        // directly. Maybe we will want to use some custom ops instead at some point.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        // top_x contains the row indexes to evaluate for each expert.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }
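        // Example: with num_experts_per_tok = 2 and per-row softmax weights
        // [0.1, 0.5, 0.3, 0.1], experts 1 and 2 are selected; their weights are
        // renormalized over the selected pair to 0.5 / 0.8 = 0.625 and
        // 0.3 / 0.8 = 0.375 before scaling the expert outputs below.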

        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Index the correct hidden states and compute the expert hidden state for
            // the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
            let current_hidden_states = expert_layer.forward(
                &current_state,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

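// One Mixtral transformer block: pre-norm self-attention followed by a
// pre-norm sparse MoE feed-forward, each wrapped in a residual connection.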
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            vb.pp("block_sparse_moe"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .block_sparse_moe
            .forward(
                &xs.apply(&self.post_attention_layernorm)?,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

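/// The full X-LoRA Mixtral model: token embeddings, a stack of `DecoderLayer`s,
/// a final `RmsNorm`, and the `lm_head`, plus an optional `XLoraClassifier`
/// that produces per-token adapter scalings.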
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    dtype: DType,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .merge_weights()?;
                for expert in layer.block_sparse_moe.experts.iter_mut() {
                    Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
                }
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // This is why we can pass dummy values (..., None, 1.0, None)
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

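    // Runs the decoder stack once. `is_full_pass` selects the dedicated X-LoRA
    // KV cache (cleared first when `no_kv_cache` is set); otherwise the regular
    // cache is used. The result is normalized but not yet projected to logits.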
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

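    // X-LoRA inference is two passes: if a classifier is present, a scaling
    // pass first computes per-token adapter scalings (via `get_scalings`), then
    // the model is re-run with those scalings applied in every `lora_forward`
    // call. Without a classifier this reduces to a single ordinary pass.
    // `lm_head` is applied with dummy LoRA arguments (None, 1.0, None) since an
    // adapter on `lm_head` is rejected at load time.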
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

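// Exposes every quantizable linear (lm_head, attention projections, MoE gate,
// and expert weights) with its layer index so in-situ quantization (ISQ) can
// quantize and place them per device.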
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}