mistralrs_core/xlora_models/starcoder2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{DType, Device, Module, Result, Tensor};
4use candle_nn::LayerNorm;
5use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
6use std::{collections::HashMap, sync::Arc};
7use tqdm::Iter;
8use tracing::info;
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    attention::SdpaParams,
13    device_map::DeviceMapper,
14    layers::{self, layer_norm, Activation, CausalMasker, RotaryEmbedding, Sdpa},
15    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
16    models::starcoder2::Config,
17    paged_attention::ModelConfigMetadata,
18    pipeline::{
19        extract_logits,
20        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
22    },
23    utils::progress::NiceProgressBar,
24    Ordering,
25};
26
27use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};
28
29#[derive(Clone)]
30#[allow(clippy::upper_case_acronyms)]
31struct MLP {
32    c_fc: Arc<dyn LinearLayerLike + Send + Sync>,
33    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
34    act: Activation,
35}
36
37impl MLP {
38    #[allow(clippy::too_many_arguments)]
39    fn new(
40        cfg: &Config,
41        vb: ShardedVarBuilder,
42        lora_config: &[((String, String), LoraConfig)],
43        count: &mut usize,
44        ord: &Ordering,
45        mapper: &dyn DeviceMapper,
46        layer_idx: usize,
47        loading_isq: bool,
48        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
49    ) -> Result<Self> {
50        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
51        let c_fc = linear_b(
52            h_size,
53            i_size,
54            cfg.use_bias,
55            mapper.set_device(layer_idx, vb.pp("c_fc"), loading_isq),
56            mapper.set_device(layer_idx, vb.pp("c_fc"), false),
57            lora_config,
58            count,
59            ord,
60            preload_adapters,
61        )?;
62        let c_proj = linear_b(
63            i_size,
64            h_size,
65            cfg.use_bias,
66            mapper.set_device(layer_idx, vb.pp("c_proj"), loading_isq),
67            mapper.set_device(layer_idx, vb.pp("c_proj"), false),
68            lora_config,
69            count,
70            ord,
71            preload_adapters,
72        )?;
73        Ok(Self {
74            c_fc,
75            c_proj,
76            act: cfg.hidden_act,
77        })
78    }
79
80    fn forward(
81        &self,
82        xs: &Tensor,
83        scalings: Option<Tensor>,
84        global_scaling_weight: f64,
85        is_scaling_pass: Option<f64>,
86    ) -> Result<Tensor> {
87        let original_dtype = xs.dtype();
88        let mut xs = xs.clone();
89        if let Some(t) = self.c_fc.quantized_act_type() {
90            xs = xs.to_dtype(t)?;
91        }
92        let mut res = self.c_proj.lora_forward(
93            &self
94                .c_fc
95                .lora_forward(
96                    &xs,
97                    scalings.clone(),
98                    global_scaling_weight,
99                    is_scaling_pass,
100                )?
101                .apply(&self.act)?,
102            scalings,
103            global_scaling_weight,
104            is_scaling_pass,
105        )?;
106        if self.c_fc.quantized_act_type().is_some() {
107            res = res.to_dtype(original_dtype)?;
108        }
109        Ok(res)
110    }
111}
112
113struct Attention {
114    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
115    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
116    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
117    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
118    num_heads: usize,
119    num_kv_heads: usize,
120    head_dim: usize,
121    hidden_size: usize,
122    rotary_emb: Arc<RotaryEmbedding>,
123    sliding_window: Option<usize>,
124    sdpa_params: SdpaParams,
125}
126
127impl Attention {
128    #[allow(clippy::too_many_arguments)]
129    fn new(
130        rotary_emb: Arc<RotaryEmbedding>,
131        cfg: &Config,
132        vb: ShardedVarBuilder,
133        lora_config: &[((String, String), LoraConfig)],
134        count: &mut usize,
135        ord: &Ordering,
136        mapper: &dyn DeviceMapper,
137        layer_idx: usize,
138        loading_isq: bool,
139        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
140    ) -> Result<Self> {
141        let hidden_sz = cfg.hidden_size;
142        let num_heads = cfg.num_attention_heads;
143        let num_kv_heads = cfg.num_key_value_heads;
144        let head_dim = hidden_sz / num_heads;
145        let b = cfg.use_bias;
146        let q_proj = linear_b(
147            hidden_sz,
148            num_heads * head_dim,
149            b,
150            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
151            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
152            lora_config,
153            count,
154            ord,
155            preload_adapters,
156        )?;
157        let k_proj = linear_b(
158            hidden_sz,
159            num_kv_heads * head_dim,
160            b,
161            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
162            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
163            lora_config,
164            count,
165            ord,
166            preload_adapters,
167        )?;
168        let v_proj = linear_b(
169            hidden_sz,
170            num_kv_heads * head_dim,
171            b,
172            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
173            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
174            lora_config,
175            count,
176            ord,
177            preload_adapters,
178        )?;
179        let o_proj = linear_b(
180            num_heads * head_dim,
181            hidden_sz,
182            b,
183            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
184            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
185            lora_config,
186            count,
187            ord,
188            preload_adapters,
189        )?;
190        Ok(Self {
191            q_proj,
192            k_proj,
193            v_proj,
194            o_proj,
195            num_heads,
196            num_kv_heads,
197            head_dim,
198            hidden_size: hidden_sz,
199            rotary_emb,
200            sliding_window: cfg.sliding_window,
201            sdpa_params: SdpaParams {
202                n_kv_groups: num_heads / num_kv_heads,
203                use_flash_attn: cfg.use_flash_attn,
204                softcap: None,
205                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
206                sliding_window: cfg.sliding_window,
207            },
208        })
209    }
210
211    #[allow(clippy::too_many_arguments)]
212    fn forward(
213        &self,
214        xs: &Tensor,
215        attention_mask: Option<&Tensor>,
216        seqlen_offsets: &[usize],
217        kv_cache: &mut Option<(Tensor, Tensor)>,
218        scalings: Option<Tensor>,
219        global_scaling_weight: f64,
220        is_scaling_pass: Option<f64>,
221        flash_params: &FlashParams,
222    ) -> Result<Tensor> {
223        let (b_sz, q_len, _) = xs.dims3()?;
224
225        let original_dtype = xs.dtype();
226        let mut xs = xs.clone();
227        if let Some(t) = self.q_proj.quantized_act_type() {
228            xs = xs.to_dtype(t)?;
229        }
230        let mut q = self.q_proj.lora_forward(
231            &xs,
232            scalings.clone(),
233            global_scaling_weight,
234            is_scaling_pass,
235        )?;
236        let mut k = self.k_proj.lora_forward(
237            &xs,
238            scalings.clone(),
239            global_scaling_weight,
240            is_scaling_pass,
241        )?;
242        let mut v = self.v_proj.lora_forward(
243            &xs,
244            scalings.clone(),
245            global_scaling_weight,
246            is_scaling_pass,
247        )?;
248        if self.q_proj.quantized_act_type().is_some() {
249            q = q.to_dtype(original_dtype)?;
250            k = k.to_dtype(original_dtype)?;
251            v = v.to_dtype(original_dtype)?;
252        }
253
254        let (q, k, v) = if q_len != 1 {
255            let q = q
256                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
257                .transpose(1, 2)?;
258            let k = k
259                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
260                .transpose(1, 2)?;
261            let v = v
262                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
263                .transpose(1, 2)?;
264            (q, k, v)
265        } else {
266            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
267            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
268            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
269            (q, k, v)
270        };
271
272        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
273
274        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
275            kv_cache,
276            k,
277            v,
278            attention_mask,
279            self.sliding_window,
280            false,
281        )?;
282
283        let mut attn_output = Sdpa.run_attention(
284            &q,
285            &k,
286            &v,
287            attn_mask.as_ref(),
288            Some(flash_params),
289            &self.sdpa_params,
290        )?;
291
292        if let Some(t) = self.q_proj.quantized_act_type() {
293            attn_output = attn_output.to_dtype(t)?;
294        }
295        let mut res = self.o_proj.lora_forward(
296            &attn_output
297                .transpose(1, 2)?
298                .reshape((b_sz, q_len, self.hidden_size))?,
299            scalings.clone(),
300            global_scaling_weight,
301            is_scaling_pass,
302        )?;
303        if self.q_proj.quantized_act_type().is_some() {
304            res = res.to_dtype(original_dtype)?;
305        }
306        Ok(res)
307    }
308}
309
310struct DecoderLayer {
311    self_attn: Attention,
312    mlp: MLP,
313    input_layernorm: LayerNorm,
314    post_attention_layernorm: LayerNorm,
315}
316
317impl DecoderLayer {
318    #[allow(clippy::too_many_arguments)]
319    fn new(
320        rotary_emb: Arc<RotaryEmbedding>,
321        cfg: &Config,
322        vb: ShardedVarBuilder,
323        lora_config: &[((String, String), LoraConfig)],
324        count: &mut usize,
325        ord: &Ordering,
326        mapper: &dyn DeviceMapper,
327        layer_idx: usize,
328        loading_isq: bool,
329        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
330    ) -> Result<Self> {
331        let self_attn = Attention::new(
332            rotary_emb,
333            cfg,
334            vb.pp("self_attn"),
335            lora_config,
336            count,
337            ord,
338            mapper,
339            layer_idx,
340            loading_isq,
341            preload_adapters,
342        )?;
343        let mlp = MLP::new(
344            cfg,
345            vb.pp("mlp"),
346            lora_config,
347            count,
348            ord,
349            mapper,
350            layer_idx,
351            loading_isq,
352            preload_adapters,
353        )?;
354        let input_layernorm = layer_norm(
355            cfg.hidden_size,
356            cfg.norm_epsilon,
357            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
358        )?;
359        let post_attention_layernorm = layer_norm(
360            cfg.hidden_size,
361            cfg.norm_epsilon,
362            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
363        )?;
364        Ok(Self {
365            self_attn,
366            mlp,
367            input_layernorm,
368            post_attention_layernorm,
369        })
370    }
371
372    #[allow(clippy::too_many_arguments)]
373    fn forward(
374        &self,
375        xs: &Tensor,
376        attention_mask: Option<&Tensor>,
377        seqlen_offsets: &[usize],
378        kv_cache: &mut Option<(Tensor, Tensor)>,
379        scalings: Option<Tensor>,
380        global_scaling_weight: f64,
381        is_scaling_pass: Option<f64>,
382        flash_params: &FlashParams,
383    ) -> Result<Tensor> {
384        let residual = xs;
385        let xs = self.input_layernorm.forward(xs)?;
386        let xs = self.self_attn.forward(
387            &xs,
388            attention_mask,
389            seqlen_offsets,
390            kv_cache,
391            scalings.clone(),
392            global_scaling_weight,
393            is_scaling_pass,
394            flash_params,
395        )?;
396        let xs = (xs + residual)?;
397        let residual = &xs;
398        let xs = self.mlp.forward(
399            &xs.apply(&self.post_attention_layernorm)?,
400            scalings.clone(),
401            global_scaling_weight,
402            is_scaling_pass,
403        )?;
404        residual + xs
405    }
406}
407
408pub struct Model {
409    embed_tokens: candle_nn::Embedding,
410    layers: Vec<DecoderLayer>,
411    norm: LayerNorm,
412    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
413    sliding_window: Option<usize>,
414    device: Device,
415    cache: EitherCache,
416    max_seq_len: usize,
417    mapper: Box<dyn DeviceMapper + Send + Sync>,
418    xlora_classifier: Option<XLoraClassifier>,
419    dtype: DType,
420    cfg: ModelConfigMetadata,
421}
422
423impl Model {
424    #[allow(clippy::too_many_arguments)]
425    pub fn new(
426        cfg: &Config,
427        vb: ShardedVarBuilder,
428        lora_config: &[((String, String), LoraConfig)],
429        xlora_config: Option<XLoraConfig>,
430        xlora_ordering: Ordering,
431        is_gptx: bool,
432        normal_loading_metadata: NormalLoadingMetadata,
433        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
434    ) -> Result<Self> {
435        if let Some(ref quant_cfg) = &cfg.quantization_config {
436            tracing::info!(
437                "Using {} quantization: {}.",
438                quant_cfg.name(),
439                quant_cfg.get_bits_name(&vb)
440            );
441        }
442        let mapper = normal_loading_metadata.mapper;
443        let vb_m = vb.pp("model");
444
445        let embed_tokens = layers::embedding(
446            cfg.vocab_size,
447            cfg.hidden_size,
448            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
449            &cfg.quantization_config,
450        )?;
451        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
452        let vb_l = vb_m.pp("layers");
453        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
454        let mut ropes = HashMap::new();
455        for layer_idx in 0..cfg.num_hidden_layers {
456            let device = mapper
457                .device_for(layer_idx, false)
458                .unwrap_or(&normal_loading_metadata.real_device);
459            ropes.insert(
460                device.location(),
461                Arc::new(RotaryEmbedding::new(
462                    cfg.rope_theta as f32,
463                    head_dim,
464                    cfg.max_position_embeddings,
465                    device,
466                    is_gptx,
467                    vb_m.dtype(),
468                )?),
469            );
470        }
471        let mut count = 0;
472        for layer_idx in NiceProgressBar::<_, 'b'>(
473            0..cfg.num_hidden_layers,
474            "Loading repeating layers",
475            &normal_loading_metadata.multi_progress,
476        ) {
477            let device = mapper
478                .device_for(layer_idx, false)
479                .unwrap_or(&normal_loading_metadata.real_device);
480            let rotary_emb = ropes
481                .get(&device.location())
482                .expect("No RoPE for device location!")
483                .clone();
484            layers.push(DecoderLayer::new(
485                rotary_emb.clone(),
486                cfg,
487                vb_l.pp(layer_idx),
488                lora_config,
489                &mut count,
490                &xlora_ordering,
491                &*mapper,
492                layer_idx,
493                normal_loading_metadata.loading_isq,
494                preload_adapters,
495            )?)
496        }
497        if xlora_config.is_none() && preload_adapters.is_none() {
498            // We are now a LoRA model so we must merge the weights
499            info!("Merging LoRA adapters.");
500            for layer in layers.iter_mut().tqdm() {
501                Arc::get_mut(&mut layer.self_attn.k_proj)
502                    .unwrap()
503                    .merge_weights()?;
504                Arc::get_mut(&mut layer.self_attn.o_proj)
505                    .unwrap()
506                    .merge_weights()?;
507                Arc::get_mut(&mut layer.self_attn.q_proj)
508                    .unwrap()
509                    .merge_weights()?;
510                Arc::get_mut(&mut layer.self_attn.v_proj)
511                    .unwrap()
512                    .merge_weights()?;
513
514                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().merge_weights()?;
515                Arc::get_mut(&mut layer.mlp.c_proj)
516                    .unwrap()
517                    .merge_weights()?;
518            }
519        }
520        let norm = layer_norm(
521            cfg.hidden_size,
522            cfg.norm_epsilon,
523            mapper.set_nm_device(vb_m.pp("norm"), false),
524        )?;
525        let lm_head = linear_no_bias(
526            embed_tokens.embeddings().dim(1)?,
527            embed_tokens.embeddings().dim(0)?,
528            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
529            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
530            lora_config,
531            &mut count,
532            &xlora_ordering,
533            preload_adapters,
534        )?;
535        if xlora_config.is_some() && lm_head.is_lora() {
536            // This is why we can pass dummy values (..., None, 1.0, None)?
537            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
538        }
539        Ok(Self {
540            embed_tokens,
541            layers,
542            norm,
543            lm_head,
544            sliding_window: cfg.sliding_window,
545            device: normal_loading_metadata.real_device,
546            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
547            max_seq_len: cfg.max_position_embeddings,
548            mapper,
549            dtype: vb.dtype(),
550            xlora_classifier: xlora_config.map(|xlora_config| {
551                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
552            }),
553            cfg: ModelConfigMetadata {
554                max_seq_len: cfg.max_position_embeddings,
555                num_layers: cfg.num_hidden_layers,
556                hidden_size: cfg.hidden_size,
557                num_kv_heads: cfg.num_key_value_heads,
558                num_attn_heads: cfg.num_attention_heads,
559                sliding_window: cfg.sliding_window,
560                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
561                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
562            },
563        })
564    }
565
566    #[allow(clippy::too_many_arguments)]
567    fn inner_forward(
568        &self,
569        input_ids: &Tensor,
570        seqlen_offsets: &[usize],
571        scalings: Option<Tensor>,
572        is_full_pass: bool,
573        no_kv_cache: bool,
574        is_scaling_pass: Option<f64>,
575        flash_params: &FlashParams,
576    ) -> Result<Tensor> {
577        let mut xs = self.embed_tokens.forward(input_ids)?;
578
579        let mut cache = if is_full_pass {
580            if no_kv_cache {
581                let mut new_cache = Vec::new();
582                for _ in 0..self.cache.full().xlora_lock().len() {
583                    new_cache.push(None);
584                }
585
586                self.cache.full().xlora_lock().clone_from(&new_cache);
587            }
588            self.cache.full().xlora_lock()
589        } else {
590            self.cache.full().lock()
591        };
592        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
593            input_ids,
594            &*cache,
595            self.sliding_window,
596            xs.dtype(),
597            self.cfg.num_attn_heads,
598        )?;
599
600        for (i, layer) in self.layers.iter().enumerate() {
601            xs = self.mapper.map(xs, i)?;
602            xs = layer.forward(
603                &xs,
604                attention_mask
605                    .as_ref()
606                    .map(|m| m.to_device(xs.device()).unwrap())
607                    .as_ref(),
608                seqlen_offsets,
609                &mut cache[i],
610                scalings.clone(),
611                self.xlora_classifier
612                    .as_ref()
613                    .map(|classifier| classifier.get_global_scaling_weight())
614                    .unwrap_or(1.0),
615                is_scaling_pass,
616                flash_params,
617            )?
618        }
619        let xs = xs.to_device(&self.device)?;
620        xs.apply(&self.norm)
621    }
622
623    #[allow(clippy::too_many_arguments)]
624    pub fn forward(
625        &self,
626        input_ids: &Tensor,
627        input_ids_full: &Tensor,
628        seqlen_offsets: &[usize],
629        seqlen_offsets_full: &[usize],
630        no_kv_cache: bool,
631        non_granular_state: &Option<NonGranularState>,
632        context_lens: Vec<(usize, usize)>,
633        flash_params: &FlashParams,
634        flash_params_full: &FlashParams,
635    ) -> Result<Tensor> {
636        if self.xlora_classifier.is_some() {
637            let scalings = self.get_scalings(
638                input_ids,
639                input_ids_full,
640                seqlen_offsets,
641                seqlen_offsets_full,
642                no_kv_cache,
643                non_granular_state,
644                &vec![usize::MAX; context_lens.len()],
645                flash_params,
646                flash_params_full,
647            )?;
648
649            if no_kv_cache {
650                let mut res = self
651                    .inner_forward(
652                        input_ids_full,
653                        seqlen_offsets_full,
654                        Some(scalings),
655                        true,
656                        no_kv_cache,
657                        None,
658                        flash_params_full,
659                    )?
660                    .contiguous()?;
661                if let Some(t) = self.lm_head.quantized_act_type() {
662                    res = res.to_dtype(t)?;
663                }
664                extract_logits(
665                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
666                    context_lens,
667                )
668            } else {
669                // is_full_pass=true is ok because no_kv_cache=false
670                let mut res = self
671                    .inner_forward(
672                        input_ids,
673                        seqlen_offsets,
674                        Some(scalings),
675                        true,
676                        no_kv_cache,
677                        None,
678                        flash_params,
679                    )?
680                    .contiguous()?;
681                if let Some(t) = self.lm_head.quantized_act_type() {
682                    res = res.to_dtype(t)?;
683                }
684                extract_logits(
685                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
686                    context_lens,
687                )
688            }
689        } else {
690            let mut res = self
691                .inner_forward(
692                    input_ids,
693                    seqlen_offsets,
694                    None,
695                    false,
696                    no_kv_cache,
697                    None,
698                    flash_params,
699                )?
700                .contiguous()?;
701            if let Some(t) = self.lm_head.quantized_act_type() {
702                res = res.to_dtype(t)?;
703            }
704            extract_logits(
705                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
706                context_lens,
707            )
708        }
709    }
710}
711
712impl IsqModel for Model {
713    fn get_layers(
714        &mut self,
715    ) -> (
716        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
717        &dyn DeviceMapper,
718    ) {
719        let mut tensors = Vec::new();
720        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
721        for (i, layer) in self.layers.iter_mut().enumerate() {
722            tensors.push((
723                Arc::get_mut(&mut layer.self_attn.q_proj)
724                    .unwrap()
725                    .quant_inner(),
726                Some(i),
727            ));
728            tensors.push((
729                Arc::get_mut(&mut layer.self_attn.k_proj)
730                    .unwrap()
731                    .quant_inner(),
732                Some(i),
733            ));
734            tensors.push((
735                Arc::get_mut(&mut layer.self_attn.v_proj)
736                    .unwrap()
737                    .quant_inner(),
738                Some(i),
739            ));
740            tensors.push((
741                Arc::get_mut(&mut layer.self_attn.o_proj)
742                    .unwrap()
743                    .quant_inner(),
744                Some(i),
745            ));
746            tensors.push((
747                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().quant_inner(),
748                Some(i),
749            ));
750            tensors.push((
751                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
752                Some(i),
753            ));
754        }
755        (tensors, &*self.mapper)
756    }
757
758    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
759        panic!("Cannot generate UQFF for an adapter model.")
760    }
761}
762
763impl NormalModel for Model {
764    fn forward(
765        &self,
766        _input_ids: &Tensor,
767        _seqlen_offsets: &[usize],
768        _context_lens: Vec<(usize, usize)>,
769        _position_ids: Vec<usize>,
770        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
771        _flash_params: &FlashParams,
772    ) -> Result<Tensor> {
773        unimplemented!()
774    }
775    fn xlora_forward(
776        &self,
777        input_ids: &Tensor,
778        input_ids_full: &Tensor,
779        seqlen_offsets: &[usize],
780        seqlen_offsets_full: &[usize],
781        no_kv_cache: bool,
782        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
783        context_lens: Vec<(usize, usize)>,
784        _position_ids: Vec<usize>,
785        flash_params: &FlashParams,
786        flash_params_full: &FlashParams,
787    ) -> Result<Tensor> {
788        self.forward(
789            input_ids,
790            input_ids_full,
791            seqlen_offsets,
792            seqlen_offsets_full,
793            no_kv_cache,
794            non_granular_state,
795            context_lens,
796            flash_params,
797            flash_params_full,
798        )
799    }
800    fn cache(&self) -> &EitherCache {
801        &self.cache
802    }
803    fn cache_mut(&mut self) -> &mut EitherCache {
804        &mut self.cache
805    }
806    fn device(&self) -> &Device {
807        &self.device
808    }
809    fn is_xlora(&self) -> bool {
810        false
811    }
812    fn max_seq_len(&self) -> usize {
813        self.max_seq_len
814    }
815    fn config(&self) -> &ModelConfigMetadata {
816        &self.cfg
817    }
818}
819
820impl ScalingsMaker for Model {
821    fn dtype(&self) -> DType {
822        self.dtype
823    }
824    fn get_cache(&self) -> &EitherCache {
825        &self.cache
826    }
827    fn get_classifier(&self) -> &XLoraClassifier {
828        self.xlora_classifier.as_ref().unwrap()
829    }
830    fn forward(
831        &self,
832        input_ids: &Tensor,
833        seqlen_offsets: &[usize],
834        scalings: Tensor,
835        is_full_pass: bool,
836        no_kv_cache: bool,
837        is_scaling_pass: Option<f64>,
838        _context_lens: &[usize],
839        flash_params: &FlashParams,
840    ) -> Result<Tensor> {
841        self.inner_forward(
842            input_ids,
843            seqlen_offsets,
844            Some(scalings),
845            is_full_pass,
846            no_kv_cache,
847            is_scaling_pass,
848            flash_params,
849        )
850    }
851}
852
853impl AnyMoeBaseModelMixin for Model {}