mistralrs_core/xlora_models/gemma2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{DType, Device, Module, Result, Tensor};
6use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
7use tqdm::Iter;
8use tracing::info;
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    attention::SdpaParams,
13    device_map::DeviceMapper,
14    layers::{self, Activation, CausalMasker, RmsNorm, RotaryEmbedding, Sdpa},
15    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
16    models::gemma2::Config,
17    paged_attention::ModelConfigMetadata,
18    pipeline::{
19        extract_logits,
20        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
22    },
23    utils::progress::NiceProgressBar,
24    Ordering,
25};
26
27use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};
28
29#[derive(Clone)]
30#[allow(clippy::upper_case_acronyms)]
31struct MLP {
32    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
33    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
34    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
35    act_fn: Activation,
36}
37
38impl MLP {
39    #[allow(clippy::too_many_arguments)]
40    fn new(
41        cfg: &Config,
42        vb: ShardedVarBuilder,
43        lora_config: &[((String, String), LoraConfig)],
44        count: &mut usize,
45        ord: &Ordering,
46        mapper: &dyn DeviceMapper,
47        layer_idx: usize,
48        loading_isq: bool,
49        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
50    ) -> Result<Self> {
51        let hidden_sz = cfg.hidden_size;
52        let intermediate_sz = cfg.intermediate_size;
53        let gate_proj = linear_b(
54            hidden_sz,
55            intermediate_sz,
56            false,
57            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
58            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
59            lora_config,
60            count,
61            ord,
62            preload_adapters,
63        )?;
64        let up_proj = linear_b(
65            hidden_sz,
66            intermediate_sz,
67            false,
68            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
69            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
70            lora_config,
71            count,
72            ord,
73            preload_adapters,
74        )?;
75        let down_proj = linear_b(
76            intermediate_sz,
77            hidden_sz,
78            false,
79            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
80            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
81            lora_config,
82            count,
83            ord,
84            preload_adapters,
85        )?;
86        Ok(Self {
87            gate_proj,
88            up_proj,
89            down_proj,
90            act_fn: cfg.hidden_act()?,
91        })
92    }
93
94    fn forward(
95        &self,
96        xs: &Tensor,
97        scalings: Option<Tensor>,
98        global_scaling_weight: f64,
99        is_scaling_pass: Option<f64>,
100    ) -> Result<Tensor> {
101        let original_dtype = xs.dtype();
102        let mut xs = xs.clone();
103        if let Some(t) = self.gate_proj.quantized_act_type() {
104            xs = xs.to_dtype(t)?;
105        }
106        let lhs = self
107            .gate_proj
108            .lora_forward(
109                &xs,
110                scalings.clone(),
111                global_scaling_weight,
112                is_scaling_pass,
113            )?
114            .apply(&self.act_fn)?;
115        let rhs = self.up_proj.lora_forward(
116            &xs,
117            scalings.clone(),
118            global_scaling_weight,
119            is_scaling_pass,
120        )?;
121        let mut res = self.down_proj.lora_forward(
122            &(lhs * rhs)?,
123            scalings,
124            global_scaling_weight,
125            is_scaling_pass,
126        )?;
127        if self.gate_proj.quantized_act_type().is_some() {
128            res = res.to_dtype(original_dtype)?;
129        }
130        Ok(res)
131    }
132}
133
134struct Attention {
135    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
136    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
137    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
138    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
139    num_heads: usize,
140    num_kv_heads: usize,
141    head_dim: usize,
142    rotary_emb: Arc<RotaryEmbedding>,
143    use_sliding_window: bool,
144    sliding_window: Option<usize>,
145    sdpa_params: SdpaParams,
146}
147
148impl Attention {
149    #[allow(clippy::too_many_arguments)]
150    fn new(
151        rotary_emb: Arc<RotaryEmbedding>,
152        cfg: &Config,
153        vb: ShardedVarBuilder,
154        lora_config: &[((String, String), LoraConfig)],
155        count: &mut usize,
156        ord: &Ordering,
157        mapper: &dyn DeviceMapper,
158        layer_idx: usize,
159        loading_isq: bool,
160        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
161    ) -> Result<Self> {
162        let hidden_sz = cfg.hidden_size;
163        let num_heads = cfg.num_attention_heads;
164        let num_kv_heads = cfg.num_key_value_heads;
165        let head_dim = cfg.head_dim;
166        let bias = cfg.attention_bias;
167        let q_proj = linear_b(
168            hidden_sz,
169            num_heads * head_dim,
170            bias,
171            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
172            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
173            lora_config,
174            count,
175            ord,
176            preload_adapters,
177        )?;
178        let k_proj = linear_b(
179            hidden_sz,
180            num_kv_heads * head_dim,
181            bias,
182            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
183            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
184            lora_config,
185            count,
186            ord,
187            preload_adapters,
188        )?;
189        let v_proj = linear_b(
190            hidden_sz,
191            num_kv_heads * head_dim,
192            bias,
193            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
194            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
195            lora_config,
196            count,
197            ord,
198            preload_adapters,
199        )?;
200        let o_proj = linear_b(
201            num_heads * head_dim,
202            hidden_sz,
203            bias,
204            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
205            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
206            lora_config,
207            count,
208            ord,
209            preload_adapters,
210        )?;
211        let sliding_window = if layer_idx % 2 == 0 {
212            // ^ Order is SWA, global, SWA
213            Some(cfg.sliding_window)
214        } else {
215            None
216        };
217        Ok(Self {
218            q_proj,
219            k_proj,
220            v_proj,
221            o_proj,
222            num_heads,
223            num_kv_heads,
224            head_dim,
225            rotary_emb,
226            use_sliding_window: layer_idx % 2 == 0, // Order is SWA, global, SWA
227            sliding_window,
228            sdpa_params: SdpaParams {
229                n_kv_groups: num_heads / num_kv_heads,
230                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
231                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
232                sliding_window,
233            },
234        })
235    }
236
237    #[allow(clippy::too_many_arguments)]
238    fn forward(
239        &self,
240        xs: &Tensor,
241        attention_mask: Option<&Tensor>,
242        sliding_attention_mask: Option<&Tensor>,
243        seqlen_offsets: &[usize],
244        kv_cache: &mut Option<(Tensor, Tensor)>,
245        scalings: Option<Tensor>,
246        global_scaling_weight: f64,
247        is_scaling_pass: Option<f64>,
248        flash_params: &FlashParams,
249    ) -> Result<Tensor> {
250        let (b_sz, q_len, _) = xs.dims3()?;
251
252        let original_dtype = xs.dtype();
253        let mut xs = xs.clone();
254        if let Some(t) = self.q_proj.quantized_act_type() {
255            xs = xs.to_dtype(t)?;
256        }
257        let mut q = self.q_proj.lora_forward(
258            &xs,
259            scalings.clone(),
260            global_scaling_weight,
261            is_scaling_pass,
262        )?;
263        let mut k = self.k_proj.lora_forward(
264            &xs,
265            scalings.clone(),
266            global_scaling_weight,
267            is_scaling_pass,
268        )?;
269        let mut v = self.v_proj.lora_forward(
270            &xs,
271            scalings.clone(),
272            global_scaling_weight,
273            is_scaling_pass,
274        )?;
275        if self.q_proj.quantized_act_type().is_some() {
276            q = q.to_dtype(original_dtype)?;
277            k = k.to_dtype(original_dtype)?;
278            v = v.to_dtype(original_dtype)?;
279        }
280
281        let (q, k, v) = if q_len != 1 {
282            let q = q
283                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
284                .transpose(1, 2)?;
285            let k = k
286                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
287                .transpose(1, 2)?;
288            let v = v
289                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
290                .transpose(1, 2)?;
291            (q, k, v)
292        } else {
293            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
294            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
295            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
296            (q, k, v)
297        };
298
299        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
300
301        let mask = if self.use_sliding_window {
302            sliding_attention_mask
303        } else {
304            attention_mask
305        };
306
307        // self.sliding_window is None if !self.use_sliding_window
308        let (k, v, mask) =
309            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, self.sliding_window)?;
310
311        let mut attn_output = Sdpa.run_attention(
312            &q,
313            &k,
314            &v,
315            mask.as_ref(),
316            Some(flash_params),
317            &self.sdpa_params,
318        )?;
319
320        if let Some(t) = self.q_proj.quantized_act_type() {
321            attn_output = attn_output.to_dtype(t)?;
322        }
323        let mut res = self.o_proj.lora_forward(
324            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
325            scalings.clone(),
326            global_scaling_weight,
327            is_scaling_pass,
328        )?;
329        if self.q_proj.quantized_act_type().is_some() {
330            res = res.to_dtype(original_dtype)?;
331        }
332        Ok(res)
333    }
334}
335
336struct DecoderLayer {
337    self_attn: Attention,
338    mlp: MLP,
339    input_layernorm: RmsNorm,
340    post_attention_layernorm: RmsNorm,
341    pre_feedforward_layernorm: RmsNorm,
342    post_feedforward_layernorm: RmsNorm,
343}
344
345impl DecoderLayer {
346    #[allow(clippy::too_many_arguments)]
347    fn new(
348        rotary_emb: Arc<RotaryEmbedding>,
349        cfg: &Config,
350        vb: ShardedVarBuilder,
351        lora_config: &[((String, String), LoraConfig)],
352        count: &mut usize,
353        ord: &Ordering,
354        mapper: &dyn DeviceMapper,
355        layer_idx: usize,
356        loading_isq: bool,
357        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
358    ) -> Result<Self> {
359        let self_attn = Attention::new(
360            rotary_emb,
361            cfg,
362            vb.pp("self_attn"),
363            lora_config,
364            count,
365            ord,
366            mapper,
367            layer_idx,
368            loading_isq,
369            preload_adapters,
370        )?;
371        let mlp = MLP::new(
372            cfg,
373            vb.pp("mlp"),
374            lora_config,
375            count,
376            ord,
377            mapper,
378            layer_idx,
379            loading_isq,
380            preload_adapters,
381        )?;
382        let input_layernorm = RmsNorm::new_gemma(
383            cfg.hidden_size,
384            cfg.rms_norm_eps,
385            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
386        )?;
387        let post_attention_layernorm = RmsNorm::new_gemma(
388            cfg.hidden_size,
389            cfg.rms_norm_eps,
390            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
391        )?;
392        let pre_feedforward_layernorm = RmsNorm::new_gemma(
393            cfg.hidden_size,
394            cfg.rms_norm_eps,
395            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
396        )?;
397        let post_feedforward_layernorm = RmsNorm::new_gemma(
398            cfg.hidden_size,
399            cfg.rms_norm_eps,
400            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
401        )?;
402        Ok(Self {
403            self_attn,
404            mlp,
405            input_layernorm,
406            post_attention_layernorm,
407            pre_feedforward_layernorm,
408            post_feedforward_layernorm,
409        })
410    }
411
412    #[allow(clippy::too_many_arguments)]
413    fn forward(
414        &self,
415        xs: &Tensor,
416        attention_mask: Option<&Tensor>,
417        sliding_attention_mask: Option<&Tensor>,
418        seqlen_offsets: &[usize],
419        kv_cache: &mut Option<(Tensor, Tensor)>,
420        scalings: Option<Tensor>,
421        global_scaling_weight: f64,
422        is_scaling_pass: Option<f64>,
423        flash_params: &FlashParams,
424    ) -> Result<Tensor> {
425        let residual = xs;
426        let xs = self.input_layernorm.forward(xs)?;
427        let xs = self
428            .self_attn
429            .forward(
430                &xs,
431                attention_mask,
432                sliding_attention_mask,
433                seqlen_offsets,
434                kv_cache,
435                scalings.clone(),
436                global_scaling_weight,
437                is_scaling_pass,
438                flash_params,
439            )?
440            .apply(&self.post_attention_layernorm)?;
441        let xs = (xs + residual)?;
442        let residual = &xs;
443        let xs = self
444            .mlp
445            .forward(
446                &xs.apply(&self.pre_feedforward_layernorm)?,
447                scalings,
448                global_scaling_weight,
449                is_scaling_pass,
450            )?
451            .apply(&self.post_feedforward_layernorm)?;
452        residual + xs
453    }
454}
455
456pub struct Model {
457    embed_tokens: candle_nn::Embedding,
458    layers: Vec<DecoderLayer>,
459    norm: RmsNorm,
460    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
461    hidden_size: usize,
462    device: Device,
463    cache: EitherCache,
464    max_seq_len: usize,
465    mapper: Box<dyn DeviceMapper + Send + Sync>,
466    sliding_window: usize,
467    final_logit_softcapping: Option<f64>,
468    xlora_classifier: Option<XLoraClassifier>,
469    dtype: DType,
470    cfg: ModelConfigMetadata,
471}
472
473impl Model {
474    #[allow(clippy::too_many_arguments)]
475    pub fn new(
476        cfg: &Config,
477        vb: ShardedVarBuilder,
478        lora_config: &[((String, String), LoraConfig)],
479        xlora_config: Option<XLoraConfig>,
480        xlora_ordering: Ordering,
481        is_gptx: bool,
482        normal_loading_metadata: NormalLoadingMetadata,
483        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
484    ) -> Result<Self> {
485        if let Some(ref quant_cfg) = &cfg.quantization_config {
486            tracing::info!(
487                "Using {} quantization: {}.",
488                quant_cfg.name(),
489                quant_cfg.get_bits_name(&vb)
490            );
491        }
492        let mapper = normal_loading_metadata.mapper;
493
494        let vb_m = vb.pp("model");
495        let embed_tokens = layers::embedding(
496            cfg.vocab_size,
497            cfg.hidden_size,
498            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
499            &cfg.quantization_config,
500        )?;
501        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
502        let vb_l = vb_m.pp("layers");
503        let mut ropes = HashMap::new();
504        for layer_idx in 0..cfg.num_hidden_layers {
505            let device = mapper
506                .device_for(layer_idx, false)
507                .unwrap_or(&normal_loading_metadata.real_device);
508            ropes.insert(
509                device.location(),
510                Arc::new(RotaryEmbedding::new(
511                    cfg.rope_theta as f32,
512                    cfg.head_dim,
513                    cfg.max_position_embeddings,
514                    device,
515                    is_gptx,
516                    vb.dtype(),
517                )?),
518            );
519        }
520        let mut count = 0;
521        for layer_idx in NiceProgressBar::<_, 'b'>(
522            0..cfg.num_hidden_layers,
523            "Loading repeating layers",
524            &normal_loading_metadata.multi_progress,
525        ) {
526            let device = mapper
527                .device_for(layer_idx, false)
528                .unwrap_or(&normal_loading_metadata.real_device);
529            let rotary_emb = ropes
530                .get(&device.location())
531                .expect("No RoPE for device location!")
532                .clone();
533            let layer = DecoderLayer::new(
534                rotary_emb.clone(),
535                cfg,
536                vb_l.pp(layer_idx),
537                lora_config,
538                &mut count,
539                &xlora_ordering,
540                &*mapper,
541                layer_idx,
542                normal_loading_metadata.loading_isq,
543                preload_adapters,
544            )?;
545            layers.push(layer)
546        }
547        if xlora_config.is_none() && preload_adapters.is_none() {
548            // We are now a LoRA model so we must merge the weights
549            info!("Merging LoRA adapters.");
550            for layer in layers.iter_mut().tqdm() {
551                Arc::get_mut(&mut layer.self_attn.k_proj)
552                    .unwrap()
553                    .merge_weights()?;
554                Arc::get_mut(&mut layer.self_attn.o_proj)
555                    .unwrap()
556                    .merge_weights()?;
557                Arc::get_mut(&mut layer.self_attn.q_proj)
558                    .unwrap()
559                    .merge_weights()?;
560                Arc::get_mut(&mut layer.self_attn.v_proj)
561                    .unwrap()
562                    .merge_weights()?;
563
564                Arc::get_mut(&mut layer.mlp.down_proj)
565                    .unwrap()
566                    .merge_weights()?;
567                Arc::get_mut(&mut layer.mlp.gate_proj)
568                    .unwrap()
569                    .merge_weights()?;
570                Arc::get_mut(&mut layer.mlp.up_proj)
571                    .unwrap()
572                    .merge_weights()?;
573            }
574        }
575        let norm = RmsNorm::new_gemma(
576            cfg.hidden_size,
577            cfg.rms_norm_eps,
578            mapper.set_nm_device(vb_m.pp("norm"), false),
579        )?;
580
581        let lm_head = linear_no_bias(
582            embed_tokens.embeddings().dim(1)?,
583            embed_tokens.embeddings().dim(0)?,
584            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
585            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
586            lora_config,
587            &mut count,
588            &xlora_ordering,
589            preload_adapters,
590        )?;
591        if xlora_config.is_some() && lm_head.is_lora() {
592            // This is why we can pass dummy values (..., None, 1.0, None)?
593            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
594        }
595
596        Ok(Self {
597            embed_tokens,
598            layers,
599            norm,
600            lm_head,
601            device: normal_loading_metadata.real_device,
602            hidden_size: cfg.hidden_size,
603            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
604            max_seq_len: cfg.max_position_embeddings,
605            mapper,
606            sliding_window: cfg.sliding_window,
607            final_logit_softcapping: cfg.final_logit_softcapping,
608            dtype: vb.dtype(),
609            xlora_classifier: xlora_config.map(|xlora_config| {
610                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
611            }),
612            cfg: ModelConfigMetadata {
613                max_seq_len: cfg.max_position_embeddings,
614                num_layers: cfg.num_hidden_layers,
615                hidden_size: cfg.hidden_size,
616                num_kv_heads: cfg.num_key_value_heads,
617                num_attn_heads: cfg.num_attention_heads,
618                sliding_window: None,
619                k_head_dim: cfg.head_dim,
620                v_head_dim: cfg.head_dim,
621            },
622        })
623    }
624
625    #[allow(clippy::too_many_arguments)]
626    fn inner_forward(
627        &self,
628        input_ids: &Tensor,
629        seqlen_offsets: &[usize],
630        scalings: Option<Tensor>,
631        is_full_pass: bool,
632        no_kv_cache: bool,
633        is_scaling_pass: Option<f64>,
634        flash_params: &FlashParams,
635    ) -> Result<Tensor> {
636        let xs = self.embed_tokens.forward(input_ids)?;
637        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
638        let mut cache = if is_full_pass {
639            if no_kv_cache {
640                let mut new_cache = Vec::new();
641                for _ in 0..self.cache.full().xlora_lock().len() {
642                    new_cache.push(None);
643                }
644
645                self.cache.full().xlora_lock().clone_from(&new_cache);
646            }
647            self.cache.full().xlora_lock()
648        } else {
649            self.cache.full().lock()
650        };
651        let attention_mask = CausalMasker.make_causal_mask_matrix(
652            input_ids,
653            &*cache,
654            xs.dtype(),
655            self.cfg.num_attn_heads,
656        )?;
657        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
658            input_ids,
659            &*cache,
660            Some(self.sliding_window),
661            xs.dtype(),
662            self.cfg.num_attn_heads,
663        )?;
664        for (i, layer) in self.layers.iter().enumerate() {
665            xs = self.mapper.map(xs, i)?;
666            xs = layer.forward(
667                &xs,
668                attention_mask
669                    .as_ref()
670                    .map(|m| m.to_device(xs.device()).unwrap())
671                    .as_ref(),
672                sliding_attention_mask
673                    .as_ref()
674                    .map(|m| m.to_device(xs.device()).unwrap())
675                    .as_ref(),
676                seqlen_offsets,
677                &mut cache[i],
678                scalings.clone(),
679                self.xlora_classifier
680                    .as_ref()
681                    .map(|classifier| classifier.get_global_scaling_weight())
682                    .unwrap_or(1.0),
683                is_scaling_pass,
684                flash_params,
685            )?;
686        }
687        let xs = xs.to_device(&self.device)?;
688        let mut xs = xs.apply(&self.norm)?;
689        if let Some(t) = self.lm_head.quantized_act_type() {
690            xs = xs.to_dtype(t)?;
691        }
692
693        let mut xs = self.lm_head.lora_forward(&xs, None, 1.0, None)?;
694
695        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
696            xs = (xs / final_logit_softcapping)?;
697            xs = xs.tanh()?;
698            xs = (xs * final_logit_softcapping)?;
699        }
700
701        Ok(xs)
702    }
703
704    #[allow(clippy::too_many_arguments)]
705    pub fn forward(
706        &self,
707        input_ids: &Tensor,
708        input_ids_full: &Tensor,
709        seqlen_offsets: &[usize],
710        seqlen_offsets_full: &[usize],
711        no_kv_cache: bool,
712        non_granular_state: &Option<NonGranularState>,
713        context_lens: Vec<(usize, usize)>,
714        flash_params: &FlashParams,
715        flash_params_full: &FlashParams,
716    ) -> Result<Tensor> {
717        if self.xlora_classifier.is_some() {
718            let scalings = self.get_scalings(
719                input_ids,
720                input_ids_full,
721                seqlen_offsets,
722                seqlen_offsets_full,
723                no_kv_cache,
724                non_granular_state,
725                &vec![usize::MAX; context_lens.len()],
726                flash_params,
727                flash_params_full,
728            )?;
729
730            if no_kv_cache {
731                let mut res = self
732                    .inner_forward(
733                        input_ids_full,
734                        seqlen_offsets_full,
735                        Some(scalings),
736                        true,
737                        no_kv_cache,
738                        None,
739                        flash_params_full,
740                    )?
741                    .contiguous()?;
742                if let Some(t) = self.lm_head.quantized_act_type() {
743                    res = res.to_dtype(t)?;
744                }
745                extract_logits(
746                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
747                    context_lens,
748                )
749            } else {
750                // is_full_pass=true is ok because no_kv_cache=false
751                let mut res = self
752                    .inner_forward(
753                        input_ids,
754                        seqlen_offsets,
755                        Some(scalings),
756                        true,
757                        no_kv_cache,
758                        None,
759                        flash_params,
760                    )?
761                    .contiguous()?;
762                if let Some(t) = self.lm_head.quantized_act_type() {
763                    res = res.to_dtype(t)?;
764                }
765                extract_logits(
766                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
767                    context_lens,
768                )
769            }
770        } else {
771            let mut res = self
772                .inner_forward(
773                    input_ids,
774                    seqlen_offsets,
775                    None,
776                    false,
777                    no_kv_cache,
778                    None,
779                    flash_params,
780                )?
781                .contiguous()?;
782            if let Some(t) = self.lm_head.quantized_act_type() {
783                res = res.to_dtype(t)?;
784            }
785            extract_logits(
786                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
787                context_lens,
788            )
789        }
790    }
791}
792
793impl IsqModel for Model {
794    fn get_layers(
795        &mut self,
796    ) -> (
797        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
798        &dyn DeviceMapper,
799    ) {
800        let mut tensors = Vec::new();
801        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
802        for (i, layer) in self.layers.iter_mut().enumerate() {
803            tensors.push((
804                Arc::get_mut(&mut layer.self_attn.q_proj)
805                    .unwrap()
806                    .quant_inner(),
807                Some(i),
808            ));
809            tensors.push((
810                Arc::get_mut(&mut layer.self_attn.k_proj)
811                    .unwrap()
812                    .quant_inner(),
813                Some(i),
814            ));
815            tensors.push((
816                Arc::get_mut(&mut layer.self_attn.v_proj)
817                    .unwrap()
818                    .quant_inner(),
819                Some(i),
820            ));
821            tensors.push((
822                Arc::get_mut(&mut layer.self_attn.o_proj)
823                    .unwrap()
824                    .quant_inner(),
825                Some(i),
826            ));
827            tensors.push((
828                Arc::get_mut(&mut layer.mlp.gate_proj)
829                    .unwrap()
830                    .quant_inner(),
831                Some(i),
832            ));
833            tensors.push((
834                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
835                Some(i),
836            ));
837            tensors.push((
838                Arc::get_mut(&mut layer.mlp.down_proj)
839                    .unwrap()
840                    .quant_inner(),
841                Some(i),
842            ));
843        }
844        (tensors, &*self.mapper)
845    }
846
847    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
848        panic!("Cannot generate UQFF for an adapter model.")
849    }
850}
851
852impl NormalModel for Model {
853    fn forward(
854        &self,
855        _input_ids: &Tensor,
856        _seqlen_offsets: &[usize],
857        _context_lens: Vec<(usize, usize)>,
858        _position_ids: Vec<usize>,
859        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
860        _flash_params: &FlashParams,
861    ) -> Result<Tensor> {
862        unreachable!()
863    }
864    fn xlora_forward(
865        &self,
866        input_ids: &Tensor,
867        input_ids_full: &Tensor,
868        seqlen_offsets: &[usize],
869        seqlen_offsets_full: &[usize],
870        no_kv_cache: bool,
871        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
872        context_lens: Vec<(usize, usize)>,
873        _position_ids: Vec<usize>,
874        flash_params: &FlashParams,
875        flash_params_full: &FlashParams,
876    ) -> Result<Tensor> {
877        self.forward(
878            input_ids,
879            input_ids_full,
880            seqlen_offsets,
881            seqlen_offsets_full,
882            no_kv_cache,
883            non_granular_state,
884            context_lens,
885            flash_params,
886            flash_params_full,
887        )
888    }
889    fn cache(&self) -> &EitherCache {
890        &self.cache
891    }
892    fn cache_mut(&mut self) -> &mut EitherCache {
893        &mut self.cache
894    }
895    fn device(&self) -> &Device {
896        &self.device
897    }
898    fn is_xlora(&self) -> bool {
899        false
900    }
901    fn max_seq_len(&self) -> usize {
902        self.max_seq_len
903    }
904    fn config(&self) -> &ModelConfigMetadata {
905        &self.cfg
906    }
907}
908
909impl ScalingsMaker for Model {
910    fn dtype(&self) -> DType {
911        self.dtype
912    }
913    fn get_cache(&self) -> &EitherCache {
914        &self.cache
915    }
916    fn get_classifier(&self) -> &XLoraClassifier {
917        self.xlora_classifier.as_ref().unwrap()
918    }
919    fn forward(
920        &self,
921        input_ids: &Tensor,
922        seqlen_offsets: &[usize],
923        scalings: Tensor,
924        is_full_pass: bool,
925        no_kv_cache: bool,
926        is_scaling_pass: Option<f64>,
927        _context_lens: &[usize],
928        flash_params: &FlashParams,
929    ) -> Result<Tensor> {
930        self.inner_forward(
931            input_ids,
932            seqlen_offsets,
933            Some(scalings),
934            is_full_pass,
935            no_kv_cache,
936            is_scaling_pass,
937            flash_params,
938        )
939    }
940}
941
// Relies entirely on the trait's default method implementations (the impl body is empty).
impl AnyMoeBaseModelMixin for Model {}