mistralrs_core/xlora_models/mixtral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
/// Mixtral Model
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, RmsNorm},
    models::mixtral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

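/// Grouped-query self-attention with rotary embeddings. Every projection is a
/// LoRA-wrapped `LinearLayerLike`, so X-LoRA scalings can be applied per call.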
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

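        // When decoding a single token (q_len == 1), the heads can be laid out directly
        // without the transpose needed for longer prompt chunks.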
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

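        // Append the new keys/values to the KV cache, respecting the sliding window,
        // and get back the (possibly adjusted) attention mask.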
        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

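/// A single expert MLP: `w2(act(w1(x)) * w3(x))`, the gated feed-forward used by each
/// Mixtral expert, with LoRA-wrapped projections.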
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn LinearLayerLike + Send + Sync>,
    w2: Arc<dyn LinearLayerLike + Send + Sync>,
    w3: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w2 = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w3 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w3"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
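        // Gated feed-forward: act(w1(x)) * w3(x), projected back down by w2.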
        let lhs = self
            .w1
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.w3.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.w2.lora_forward(
            &(lhs * rhs)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

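/// Mixture-of-experts block: the router (`gate`) scores every expert per token, the top
/// `num_experts_per_tok` experts are evaluated, and their outputs are combined using the
/// normalized routing weights.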
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn LinearLayerLike + Send + Sync>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let gate = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(
                cfg,
                vb.pp(idx),
                lora_config,
                count,
                ord,
                mapper,
                layer_idx,
                loading_isq,
                preload_adapters,
            )?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = self.gate.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // In order to extract topk, we extract the data from the tensor and manipulate it
        // directly. Maybe we will want to use some custom ops instead at some point.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        // top_x contains the row indexes to evaluate for each expert.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

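        // Accumulate each expert's contribution into `ys`: gather only the token rows routed
        // to that expert, run the expert, weight by the routing weights, and scatter back
        // with index_add.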
        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Index the correct hidden states and compute the expert hidden state for
            // the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
            let current_hidden_states = expert_layer.forward(
                &current_state,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

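/// One Mixtral decoder layer: pre-norm self-attention followed by a pre-norm sparse MoE
/// block, each wrapped in a residual connection.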
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            vb.pp("block_sparse_moe"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .block_sparse_moe
            .forward(
                &xs.apply(&self.post_attention_layernorm)?,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

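/// Mixtral model with X-LoRA support: token embedding, a stack of MoE decoder layers, a
/// final RmsNorm, a LoRA-wrapped `lm_head`, and an optional X-LoRA scalings classifier.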
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    dtype: DType,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
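        // Build one RotaryEmbedding per device so layers mapped to the same device share it.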
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .merge_weights()?;
                for expert in layer.block_sparse_moe.experts.iter_mut() {
                    Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
                }
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // An adapter on `lm_head` is rejected here, which is why `forward` can pass
            // dummy values (..., None, 1.0, None) to `lm_head.lora_forward`.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
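        // Full passes use the dedicated X-LoRA KV cache (cleared first when no_kv_cache is
        // set); other passes, including the scaling pass, use the regular per-layer cache.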
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
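        // X-LoRA runs two passes: first compute per-token adapter scalings with the
        // classifier, then run the model again applying those scalings. Without a
        // classifier this reduces to a plain (possibly merged-LoRA) forward pass.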
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}