mistralrs_core/xlora_models/starcoder2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::LayerNorm;
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, layer_norm, Activation, CausalMasker, RotaryEmbedding, Sdpa},
    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
    models::starcoder2::Config,
    paged_attention::ModelConfigMetadata,
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    Ordering,
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

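/// Feed-forward block for the X-LoRA StarCoder2 model: `c_fc` expands `hidden_size`
/// to `intermediate_size`, the activation is applied, and `c_proj` projects back.
/// Both projections are LoRA-wrapped so X-LoRA scalings flow through `lora_forward`.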
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = linear_b(
            h_size,
            i_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_fc"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_fc"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear_b(
            i_size,
            h_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.c_fc.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.c_proj.lora_forward(
            &self
                .c_fc
                .lora_forward(
                    &xs,
                    scalings.clone(),
                    global_scaling_weight,
                    is_scaling_pass,
                )?
                .apply(&self.act)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

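/// Multi-head self-attention with LoRA-wrapped q/k/v/o projections, rotary position
/// embeddings, grouped-query KV heads, and an optional sliding window applied
/// through the KV cache.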
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = linear_b(
            hidden_sz,
            num_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_b(
            num_heads * head_dim,
            hidden_sz,
            b,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

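        // Prefill (q_len > 1): reshape to (b, seq, heads, head_dim) and swap the
        // sequence and head axes. Decode (q_len == 1): a single reshape already
        // yields (b, heads, 1, head_dim), so no transpose is needed.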
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

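        // Append the new keys/values to this layer's cache; with a sliding window
        // configured, the cached context and the attention mask are trimmed to it.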
        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output
                .transpose(1, 2)?
                .reshape((b_sz, q_len, self.hidden_size))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

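/// A single decoder block: pre-norm self-attention followed by a pre-norm MLP, each
/// wrapped in a residual connection.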
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

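/// X-LoRA StarCoder2: token embeddings, the decoder stack, a final layer norm, and an
/// `lm_head` built from the `embed_tokens` weights (tied embeddings), plus an
/// optional `XLoraClassifier` that produces per-token adapter scalings.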
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    cfg: ModelConfigMetadata,
}

impl Model {
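    /// Load the model from `vb`, wiring LoRA adapters (and, when `xlora_config` is
    /// provided, the X-LoRA classifier) into every attention and MLP projection.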
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
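        // Build one rotary embedding per device location; all layers mapped to the
        // same device share it.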
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            layers.push(DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?)
        }
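        // With no X-LoRA classifier and no preloaded adapters the adapter set is
        // fixed, so the LoRA weights can be folded into the base weights at load time.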
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // This is why we can pass dummy values (..., None, 1.0, None)?
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

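    /// Run the decoder stack over `input_ids`. `scalings` carries the X-LoRA scalings
    /// (if any), `is_full_pass` selects the dedicated X-LoRA KV cache, and
    /// `is_scaling_pass` marks the preliminary pass used to compute those scalings.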
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;

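        // The full (scaled) pass uses the dedicated X-LoRA cache; the scaling pass
        // uses the regular cache. With the KV cache disabled, the X-LoRA cache is
        // cleared before the full pass.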
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

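    /// X-LoRA forward pass: with a classifier present, first compute per-token
    /// adapter scalings over the full context, then run the scaled pass and project
    /// the hidden states through `lm_head`. Without a classifier this reduces to a
    /// plain LoRA forward pass.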
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

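// ISQ support: expose every quantizable projection (the tied `lm_head` plus each
// layer's attention and MLP linears), tagged with its layer index.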
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += Arc::get_mut(&mut layer.self_attn.k_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.q_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.v_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.c_fc)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.c_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

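// `ScalingsMaker` drives the preliminary X-LoRA pass: its `forward` delegates to
// `inner_forward`, and the classifier maps the resulting hidden states to the
// per-token adapter scalings consumed by the scaled pass above.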
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}