mistralrs_core/xlora_models/
starcoder2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{DType, Device, Module, Result, Tensor};
4use candle_nn::LayerNorm;
5use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
6use std::{collections::HashMap, sync::Arc};
7use tqdm::Iter;
8use tracing::info;
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    attention::SdpaParams,
13    device_map::DeviceMapper,
14    layers::{self, layer_norm, Activation, CausalMasker, RotaryEmbedding, Sdpa},
15    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
16    models::starcoder2::Config,
17    paged_attention::ModelConfigMetadata,
18    pipeline::{
19        extract_logits,
20        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
22    },
23    utils::progress::NiceProgressBar,
24    Ordering,
25};
26
27use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};
28
/// Feed-forward block of a decoder layer: `c_fc`, then the activation, then `c_proj`.
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    /// First projection, `hidden_size` -> `intermediate_size`, loaded from the `c_fc` prefix.
    c_fc: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Second projection, `intermediate_size` -> `hidden_size`, loaded from the `c_proj` prefix.
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Activation applied to the output of `c_fc`; copied from `cfg.hidden_act`.
    act: Activation,
}
36
37impl MLP {
38    #[allow(clippy::too_many_arguments)]
39    fn new(
40        cfg: &Config,
41        vb: ShardedVarBuilder,
42        lora_config: &[((String, String), LoraConfig)],
43        count: &mut usize,
44        ord: &Ordering,
45        mapper: &dyn DeviceMapper,
46        layer_idx: usize,
47        loading_isq: bool,
48        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
49    ) -> Result<Self> {
50        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
51        let c_fc = linear_b(
52            h_size,
53            i_size,
54            cfg.use_bias,
55            mapper.set_device(layer_idx, vb.pp("c_fc"), loading_isq),
56            mapper.set_device(layer_idx, vb.pp("c_fc"), false),
57            lora_config,
58            count,
59            ord,
60            preload_adapters,
61        )?;
62        let c_proj = linear_b(
63            i_size,
64            h_size,
65            cfg.use_bias,
66            mapper.set_device(layer_idx, vb.pp("c_proj"), loading_isq),
67            mapper.set_device(layer_idx, vb.pp("c_proj"), false),
68            lora_config,
69            count,
70            ord,
71            preload_adapters,
72        )?;
73        Ok(Self {
74            c_fc,
75            c_proj,
76            act: cfg.hidden_act,
77        })
78    }
79
80    fn forward(
81        &self,
82        xs: &Tensor,
83        scalings: Option<Tensor>,
84        global_scaling_weight: f64,
85        is_scaling_pass: Option<f64>,
86    ) -> Result<Tensor> {
87        let original_dtype = xs.dtype();
88        let mut xs = xs.clone();
89        if let Some(t) = self.c_fc.quantized_act_type() {
90            xs = xs.to_dtype(t)?;
91        }
92        let mut res = self.c_proj.lora_forward(
93            &self
94                .c_fc
95                .lora_forward(
96                    &xs,
97                    scalings.clone(),
98                    global_scaling_weight,
99                    is_scaling_pass,
100                )?
101                .apply(&self.act)?,
102            scalings,
103            global_scaling_weight,
104            is_scaling_pass,
105        )?;
106        if self.c_fc.quantized_act_type().is_some() {
107            res = res.to_dtype(original_dtype)?;
108        }
109        Ok(res)
110    }
111}
112
/// Self-attention block of a decoder layer, with adapter-aware projections.
struct Attention {
    /// Query projection, `hidden_size` -> `num_heads * head_dim`.
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Key projection, `hidden_size` -> `num_kv_heads * head_dim`.
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Value projection, `hidden_size` -> `num_kv_heads * head_dim`.
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Output projection, `num_heads * head_dim` -> `hidden_size`.
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Number of query heads (`cfg.num_attention_heads`).
    num_heads: usize,
    /// Number of key/value heads (`cfg.num_key_value_heads`).
    num_kv_heads: usize,
    /// Per-head width, computed as `hidden_size / num_heads`.
    head_dim: usize,
    /// Model width (`cfg.hidden_size`); used to reshape the attention output.
    hidden_size: usize,
    /// Rotary embedding applied to the queries and keys; shared through an `Arc`.
    rotary_emb: Arc<RotaryEmbedding>,
    /// Optional sliding-window size (`cfg.sliding_window`) passed to the KV-cache update.
    sliding_window: Option<usize>,
    /// Parameters handed to `Sdpa::run_attention` (KV groups, softmax scale, window).
    sdpa_params: SdpaParams,
}
126
127impl Attention {
128    #[allow(clippy::too_many_arguments)]
129    fn new(
130        rotary_emb: Arc<RotaryEmbedding>,
131        cfg: &Config,
132        vb: ShardedVarBuilder,
133        lora_config: &[((String, String), LoraConfig)],
134        count: &mut usize,
135        ord: &Ordering,
136        mapper: &dyn DeviceMapper,
137        layer_idx: usize,
138        loading_isq: bool,
139        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
140    ) -> Result<Self> {
141        let hidden_sz = cfg.hidden_size;
142        let num_heads = cfg.num_attention_heads;
143        let num_kv_heads = cfg.num_key_value_heads;
144        let head_dim = hidden_sz / num_heads;
145        let b = cfg.use_bias;
146        let q_proj = linear_b(
147            hidden_sz,
148            num_heads * head_dim,
149            b,
150            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
151            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
152            lora_config,
153            count,
154            ord,
155            preload_adapters,
156        )?;
157        let k_proj = linear_b(
158            hidden_sz,
159            num_kv_heads * head_dim,
160            b,
161            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
162            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
163            lora_config,
164            count,
165            ord,
166            preload_adapters,
167        )?;
168        let v_proj = linear_b(
169            hidden_sz,
170            num_kv_heads * head_dim,
171            b,
172            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
173            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
174            lora_config,
175            count,
176            ord,
177            preload_adapters,
178        )?;
179        let o_proj = linear_b(
180            num_heads * head_dim,
181            hidden_sz,
182            b,
183            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
184            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
185            lora_config,
186            count,
187            ord,
188            preload_adapters,
189        )?;
190        Ok(Self {
191            q_proj,
192            k_proj,
193            v_proj,
194            o_proj,
195            num_heads,
196            num_kv_heads,
197            head_dim,
198            hidden_size: hidden_sz,
199            rotary_emb,
200            sliding_window: cfg.sliding_window,
201            sdpa_params: SdpaParams {
202                n_kv_groups: num_heads / num_kv_heads,
203                softcap: None,
204                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
205                sliding_window: cfg.sliding_window,
206            },
207        })
208    }
209
210    #[allow(clippy::too_many_arguments)]
211    fn forward(
212        &self,
213        xs: &Tensor,
214        attention_mask: Option<&Tensor>,
215        seqlen_offsets: &[usize],
216        kv_cache: &mut Option<(Tensor, Tensor)>,
217        scalings: Option<Tensor>,
218        global_scaling_weight: f64,
219        is_scaling_pass: Option<f64>,
220        flash_params: &FlashParams,
221    ) -> Result<Tensor> {
222        let (b_sz, q_len, _) = xs.dims3()?;
223
224        let original_dtype = xs.dtype();
225        let mut xs = xs.clone();
226        if let Some(t) = self.q_proj.quantized_act_type() {
227            xs = xs.to_dtype(t)?;
228        }
229        let mut q = self.q_proj.lora_forward(
230            &xs,
231            scalings.clone(),
232            global_scaling_weight,
233            is_scaling_pass,
234        )?;
235        let mut k = self.k_proj.lora_forward(
236            &xs,
237            scalings.clone(),
238            global_scaling_weight,
239            is_scaling_pass,
240        )?;
241        let mut v = self.v_proj.lora_forward(
242            &xs,
243            scalings.clone(),
244            global_scaling_weight,
245            is_scaling_pass,
246        )?;
247        if self.q_proj.quantized_act_type().is_some() {
248            q = q.to_dtype(original_dtype)?;
249            k = k.to_dtype(original_dtype)?;
250            v = v.to_dtype(original_dtype)?;
251        }
252
253        let (q, k, v) = if q_len != 1 {
254            let q = q
255                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
256                .transpose(1, 2)?;
257            let k = k
258                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
259                .transpose(1, 2)?;
260            let v = v
261                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
262                .transpose(1, 2)?;
263            (q, k, v)
264        } else {
265            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
266            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
267            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
268            (q, k, v)
269        };
270
271        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
272
273        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
274            kv_cache,
275            k,
276            v,
277            attention_mask,
278            self.sliding_window,
279        )?;
280
281        let mut attn_output = Sdpa.run_attention(
282            &q,
283            &k,
284            &v,
285            attn_mask.as_ref(),
286            Some(flash_params),
287            &self.sdpa_params,
288        )?;
289
290        if let Some(t) = self.q_proj.quantized_act_type() {
291            attn_output = attn_output.to_dtype(t)?;
292        }
293        let mut res = self.o_proj.lora_forward(
294            &attn_output
295                .transpose(1, 2)?
296                .reshape((b_sz, q_len, self.hidden_size))?,
297            scalings.clone(),
298            global_scaling_weight,
299            is_scaling_pass,
300        )?;
301        if self.q_proj.quantized_act_type().is_some() {
302            res = res.to_dtype(original_dtype)?;
303        }
304        Ok(res)
305    }
306}
307
/// One decoder layer: layer norm -> self-attention -> residual add, then
/// layer norm -> MLP -> residual add.
struct DecoderLayer {
    /// Self-attention block, loaded from the `self_attn` prefix.
    self_attn: Attention,
    /// Feed-forward block, loaded from the `mlp` prefix.
    mlp: MLP,
    /// Layer norm applied to the layer input before attention.
    input_layernorm: LayerNorm,
    /// Layer norm applied to the post-attention hidden states before the MLP.
    post_attention_layernorm: LayerNorm,
}
314
315impl DecoderLayer {
316    #[allow(clippy::too_many_arguments)]
317    fn new(
318        rotary_emb: Arc<RotaryEmbedding>,
319        cfg: &Config,
320        vb: ShardedVarBuilder,
321        lora_config: &[((String, String), LoraConfig)],
322        count: &mut usize,
323        ord: &Ordering,
324        mapper: &dyn DeviceMapper,
325        layer_idx: usize,
326        loading_isq: bool,
327        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
328    ) -> Result<Self> {
329        let self_attn = Attention::new(
330            rotary_emb,
331            cfg,
332            vb.pp("self_attn"),
333            lora_config,
334            count,
335            ord,
336            mapper,
337            layer_idx,
338            loading_isq,
339            preload_adapters,
340        )?;
341        let mlp = MLP::new(
342            cfg,
343            vb.pp("mlp"),
344            lora_config,
345            count,
346            ord,
347            mapper,
348            layer_idx,
349            loading_isq,
350            preload_adapters,
351        )?;
352        let input_layernorm = layer_norm(
353            cfg.hidden_size,
354            cfg.norm_epsilon,
355            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
356        )?;
357        let post_attention_layernorm = layer_norm(
358            cfg.hidden_size,
359            cfg.norm_epsilon,
360            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
361        )?;
362        Ok(Self {
363            self_attn,
364            mlp,
365            input_layernorm,
366            post_attention_layernorm,
367        })
368    }
369
370    #[allow(clippy::too_many_arguments)]
371    fn forward(
372        &self,
373        xs: &Tensor,
374        attention_mask: Option<&Tensor>,
375        seqlen_offsets: &[usize],
376        kv_cache: &mut Option<(Tensor, Tensor)>,
377        scalings: Option<Tensor>,
378        global_scaling_weight: f64,
379        is_scaling_pass: Option<f64>,
380        flash_params: &FlashParams,
381    ) -> Result<Tensor> {
382        let residual = xs;
383        let xs = self.input_layernorm.forward(xs)?;
384        let xs = self.self_attn.forward(
385            &xs,
386            attention_mask,
387            seqlen_offsets,
388            kv_cache,
389            scalings.clone(),
390            global_scaling_weight,
391            is_scaling_pass,
392            flash_params,
393        )?;
394        let xs = (xs + residual)?;
395        let residual = &xs;
396        let xs = self.mlp.forward(
397            &xs.apply(&self.post_attention_layernorm)?,
398            scalings.clone(),
399            global_scaling_weight,
400            is_scaling_pass,
401        )?;
402        residual + xs
403    }
404}
405
/// Starcoder2 model with adapter-aware linear layers and an optional X-LoRA classifier.
pub struct Model {
    /// Token embedding table, loaded from `model.embed_tokens`.
    embed_tokens: candle_nn::Embedding,
    /// Decoder layers in execution order.
    layers: Vec<DecoderLayer>,
    /// Final layer norm applied to the hidden states after the last decoder layer.
    norm: LayerNorm,
    /// Output head; built from the same `model.embed_tokens` prefix as the embedding table.
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Optional sliding-window size (`cfg.sliding_window`) used when building the causal mask.
    sliding_window: Option<usize>,
    /// Device the hidden states are moved to before the final norm (`real_device`).
    device: Device,
    /// KV cache; constructed as `EitherCache::Full` with one slot per hidden layer.
    cache: EitherCache,
    /// Maximum sequence length (`cfg.max_position_embeddings`).
    max_seq_len: usize,
    /// Device mapper used to place layers and to move activations between them.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// X-LoRA classifier; `Some` only when an `XLoraConfig` was supplied to `new`.
    xlora_classifier: Option<XLoraClassifier>,
    /// Dtype reported by the root var builder at construction time.
    dtype: DType,
    /// Shape metadata exposed through `NormalModel::config`.
    cfg: ModelConfigMetadata,
}
420
421impl Model {
422    #[allow(clippy::too_many_arguments)]
423    pub fn new(
424        cfg: &Config,
425        vb: ShardedVarBuilder,
426        lora_config: &[((String, String), LoraConfig)],
427        xlora_config: Option<XLoraConfig>,
428        xlora_ordering: Ordering,
429        is_gptx: bool,
430        normal_loading_metadata: NormalLoadingMetadata,
431        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
432    ) -> Result<Self> {
433        if let Some(ref quant_cfg) = &cfg.quantization_config {
434            tracing::info!(
435                "Using {} quantization: {}.",
436                quant_cfg.name(),
437                quant_cfg.get_bits_name(&vb)
438            );
439        }
440        let mapper = normal_loading_metadata.mapper;
441        let vb_m = vb.pp("model");
442
443        let embed_tokens = layers::embedding(
444            cfg.vocab_size,
445            cfg.hidden_size,
446            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
447            &cfg.quantization_config,
448        )?;
449        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
450        let vb_l = vb_m.pp("layers");
451        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
452        let mut ropes = HashMap::new();
453        for layer_idx in 0..cfg.num_hidden_layers {
454            let device = mapper
455                .device_for(layer_idx, false)
456                .unwrap_or(&normal_loading_metadata.real_device);
457            ropes.insert(
458                device.location(),
459                Arc::new(RotaryEmbedding::new(
460                    cfg.rope_theta as f32,
461                    head_dim,
462                    cfg.max_position_embeddings,
463                    device,
464                    is_gptx,
465                    vb_m.dtype(),
466                )?),
467            );
468        }
469        let mut count = 0;
470        for layer_idx in NiceProgressBar::<_, 'b'>(
471            0..cfg.num_hidden_layers,
472            "Loading repeating layers",
473            &normal_loading_metadata.multi_progress,
474        ) {
475            let device = mapper
476                .device_for(layer_idx, false)
477                .unwrap_or(&normal_loading_metadata.real_device);
478            let rotary_emb = ropes
479                .get(&device.location())
480                .expect("No RoPE for device location!")
481                .clone();
482            layers.push(DecoderLayer::new(
483                rotary_emb.clone(),
484                cfg,
485                vb_l.pp(layer_idx),
486                lora_config,
487                &mut count,
488                &xlora_ordering,
489                &*mapper,
490                layer_idx,
491                normal_loading_metadata.loading_isq,
492                preload_adapters,
493            )?)
494        }
495        if xlora_config.is_none() && preload_adapters.is_none() {
496            // We are now a LoRA model so we must merge the weights
497            info!("Merging LoRA adapters.");
498            for layer in layers.iter_mut().tqdm() {
499                Arc::get_mut(&mut layer.self_attn.k_proj)
500                    .unwrap()
501                    .merge_weights()?;
502                Arc::get_mut(&mut layer.self_attn.o_proj)
503                    .unwrap()
504                    .merge_weights()?;
505                Arc::get_mut(&mut layer.self_attn.q_proj)
506                    .unwrap()
507                    .merge_weights()?;
508                Arc::get_mut(&mut layer.self_attn.v_proj)
509                    .unwrap()
510                    .merge_weights()?;
511
512                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().merge_weights()?;
513                Arc::get_mut(&mut layer.mlp.c_proj)
514                    .unwrap()
515                    .merge_weights()?;
516            }
517        }
518        let norm = layer_norm(
519            cfg.hidden_size,
520            cfg.norm_epsilon,
521            mapper.set_nm_device(vb_m.pp("norm"), false),
522        )?;
523        let lm_head = linear_no_bias(
524            embed_tokens.embeddings().dim(1)?,
525            embed_tokens.embeddings().dim(0)?,
526            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
527            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
528            lora_config,
529            &mut count,
530            &xlora_ordering,
531            preload_adapters,
532        )?;
533        if xlora_config.is_some() && lm_head.is_lora() {
534            // This is why we can pass dummy values (..., None, 1.0, None)?
535            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
536        }
537        Ok(Self {
538            embed_tokens,
539            layers,
540            norm,
541            lm_head,
542            sliding_window: cfg.sliding_window,
543            device: normal_loading_metadata.real_device,
544            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
545            max_seq_len: cfg.max_position_embeddings,
546            mapper,
547            dtype: vb.dtype(),
548            xlora_classifier: xlora_config.map(|xlora_config| {
549                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
550            }),
551            cfg: ModelConfigMetadata {
552                max_seq_len: cfg.max_position_embeddings,
553                num_layers: cfg.num_hidden_layers,
554                hidden_size: cfg.hidden_size,
555                num_kv_heads: cfg.num_key_value_heads,
556                num_attn_heads: cfg.num_attention_heads,
557                sliding_window: cfg.sliding_window,
558                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
559                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
560            },
561        })
562    }
563
564    #[allow(clippy::too_many_arguments)]
565    fn inner_forward(
566        &self,
567        input_ids: &Tensor,
568        seqlen_offsets: &[usize],
569        scalings: Option<Tensor>,
570        is_full_pass: bool,
571        no_kv_cache: bool,
572        is_scaling_pass: Option<f64>,
573        flash_params: &FlashParams,
574    ) -> Result<Tensor> {
575        let mut xs = self.embed_tokens.forward(input_ids)?;
576
577        let mut cache = if is_full_pass {
578            if no_kv_cache {
579                let mut new_cache = Vec::new();
580                for _ in 0..self.cache.full().xlora_lock().len() {
581                    new_cache.push(None);
582                }
583
584                self.cache.full().xlora_lock().clone_from(&new_cache);
585            }
586            self.cache.full().xlora_lock()
587        } else {
588            self.cache.full().lock()
589        };
590        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
591            input_ids,
592            &*cache,
593            self.sliding_window,
594            xs.dtype(),
595            self.cfg.num_attn_heads,
596        )?;
597
598        for (i, layer) in self.layers.iter().enumerate() {
599            xs = self.mapper.map(xs, i)?;
600            xs = layer.forward(
601                &xs,
602                attention_mask
603                    .as_ref()
604                    .map(|m| m.to_device(xs.device()).unwrap())
605                    .as_ref(),
606                seqlen_offsets,
607                &mut cache[i],
608                scalings.clone(),
609                self.xlora_classifier
610                    .as_ref()
611                    .map(|classifier| classifier.get_global_scaling_weight())
612                    .unwrap_or(1.0),
613                is_scaling_pass,
614                flash_params,
615            )?
616        }
617        let xs = xs.to_device(&self.device)?;
618        xs.apply(&self.norm)
619    }
620
621    #[allow(clippy::too_many_arguments)]
622    pub fn forward(
623        &self,
624        input_ids: &Tensor,
625        input_ids_full: &Tensor,
626        seqlen_offsets: &[usize],
627        seqlen_offsets_full: &[usize],
628        no_kv_cache: bool,
629        non_granular_state: &Option<NonGranularState>,
630        context_lens: Vec<(usize, usize)>,
631        flash_params: &FlashParams,
632        flash_params_full: &FlashParams,
633    ) -> Result<Tensor> {
634        if self.xlora_classifier.is_some() {
635            let scalings = self.get_scalings(
636                input_ids,
637                input_ids_full,
638                seqlen_offsets,
639                seqlen_offsets_full,
640                no_kv_cache,
641                non_granular_state,
642                &vec![usize::MAX; context_lens.len()],
643                flash_params,
644                flash_params_full,
645            )?;
646
647            if no_kv_cache {
648                let mut res = self
649                    .inner_forward(
650                        input_ids_full,
651                        seqlen_offsets_full,
652                        Some(scalings),
653                        true,
654                        no_kv_cache,
655                        None,
656                        flash_params_full,
657                    )?
658                    .contiguous()?;
659                if let Some(t) = self.lm_head.quantized_act_type() {
660                    res = res.to_dtype(t)?;
661                }
662                extract_logits(
663                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
664                    context_lens,
665                )
666            } else {
667                // is_full_pass=true is ok because no_kv_cache=false
668                let mut res = self
669                    .inner_forward(
670                        input_ids,
671                        seqlen_offsets,
672                        Some(scalings),
673                        true,
674                        no_kv_cache,
675                        None,
676                        flash_params,
677                    )?
678                    .contiguous()?;
679                if let Some(t) = self.lm_head.quantized_act_type() {
680                    res = res.to_dtype(t)?;
681                }
682                extract_logits(
683                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
684                    context_lens,
685                )
686            }
687        } else {
688            let mut res = self
689                .inner_forward(
690                    input_ids,
691                    seqlen_offsets,
692                    None,
693                    false,
694                    no_kv_cache,
695                    None,
696                    flash_params,
697                )?
698                .contiguous()?;
699            if let Some(t) = self.lm_head.quantized_act_type() {
700                res = res.to_dtype(t)?;
701            }
702            extract_logits(
703                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
704                context_lens,
705            )
706        }
707    }
708}
709
710impl IsqModel for Model {
711    fn get_layers(
712        &mut self,
713    ) -> (
714        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
715        &dyn DeviceMapper,
716    ) {
717        let mut tensors = Vec::new();
718        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
719        for (i, layer) in self.layers.iter_mut().enumerate() {
720            tensors.push((
721                Arc::get_mut(&mut layer.self_attn.q_proj)
722                    .unwrap()
723                    .quant_inner(),
724                Some(i),
725            ));
726            tensors.push((
727                Arc::get_mut(&mut layer.self_attn.k_proj)
728                    .unwrap()
729                    .quant_inner(),
730                Some(i),
731            ));
732            tensors.push((
733                Arc::get_mut(&mut layer.self_attn.v_proj)
734                    .unwrap()
735                    .quant_inner(),
736                Some(i),
737            ));
738            tensors.push((
739                Arc::get_mut(&mut layer.self_attn.o_proj)
740                    .unwrap()
741                    .quant_inner(),
742                Some(i),
743            ));
744            tensors.push((
745                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().quant_inner(),
746                Some(i),
747            ));
748            tensors.push((
749                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
750                Some(i),
751            ));
752        }
753        (tensors, &*self.mapper)
754    }
755
756    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
757        panic!("Cannot generate UQFF for an adapter model.")
758    }
759}
760
761impl NormalModel for Model {
762    fn forward(
763        &self,
764        _input_ids: &Tensor,
765        _seqlen_offsets: &[usize],
766        _context_lens: Vec<(usize, usize)>,
767        _position_ids: Vec<usize>,
768        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
769        _flash_params: &FlashParams,
770    ) -> Result<Tensor> {
771        unimplemented!()
772    }
773    fn xlora_forward(
774        &self,
775        input_ids: &Tensor,
776        input_ids_full: &Tensor,
777        seqlen_offsets: &[usize],
778        seqlen_offsets_full: &[usize],
779        no_kv_cache: bool,
780        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
781        context_lens: Vec<(usize, usize)>,
782        _position_ids: Vec<usize>,
783        flash_params: &FlashParams,
784        flash_params_full: &FlashParams,
785    ) -> Result<Tensor> {
786        self.forward(
787            input_ids,
788            input_ids_full,
789            seqlen_offsets,
790            seqlen_offsets_full,
791            no_kv_cache,
792            non_granular_state,
793            context_lens,
794            flash_params,
795            flash_params_full,
796        )
797    }
798    fn cache(&self) -> &EitherCache {
799        &self.cache
800    }
801    fn cache_mut(&mut self) -> &mut EitherCache {
802        &mut self.cache
803    }
804    fn device(&self) -> &Device {
805        &self.device
806    }
807    fn is_xlora(&self) -> bool {
808        false
809    }
810    fn max_seq_len(&self) -> usize {
811        self.max_seq_len
812    }
813    fn config(&self) -> &ModelConfigMetadata {
814        &self.cfg
815    }
816}
817
818impl ScalingsMaker for Model {
819    fn dtype(&self) -> DType {
820        self.dtype
821    }
822    fn get_cache(&self) -> &EitherCache {
823        &self.cache
824    }
825    fn get_classifier(&self) -> &XLoraClassifier {
826        self.xlora_classifier.as_ref().unwrap()
827    }
828    fn forward(
829        &self,
830        input_ids: &Tensor,
831        seqlen_offsets: &[usize],
832        scalings: Tensor,
833        is_full_pass: bool,
834        no_kv_cache: bool,
835        is_scaling_pass: Option<f64>,
836        _context_lens: &[usize],
837        flash_params: &FlashParams,
838    ) -> Result<Tensor> {
839        self.inner_forward(
840            input_ids,
841            seqlen_offsets,
842            Some(scalings),
843            is_full_pass,
844            no_kv_cache,
845            is_scaling_pass,
846            flash_params,
847        )
848    }
849}
850
// Empty impl: this model overrides nothing and relies on whatever default methods the
// `AnyMoeBaseModelMixin` trait provides.
impl AnyMoeBaseModelMixin for Model {}