mistralrs_core/xlora_models/
gemma2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{DType, Device, Module, Result, Tensor};
6use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
7use tqdm::Iter;
8use tracing::info;
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    attention::SdpaParams,
13    device_map::DeviceMapper,
14    layers::{self, Activation, CausalMasker, RmsNorm, RotaryEmbedding, Sdpa},
15    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
16    models::gemma2::Config,
17    paged_attention::ModelConfigMetadata,
18    pipeline::{
19        extract_logits,
20        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
22    },
23    utils::progress::NiceProgressBar,
24    Ordering,
25};
26
27use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};
28
29#[derive(Clone)]
30#[allow(clippy::upper_case_acronyms)]
31struct MLP {
32    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
33    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
34    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
35    act_fn: Activation,
36}
37
38impl MLP {
39    #[allow(clippy::too_many_arguments)]
40    fn new(
41        cfg: &Config,
42        vb: ShardedVarBuilder,
43        lora_config: &[((String, String), LoraConfig)],
44        count: &mut usize,
45        ord: &Ordering,
46        mapper: &dyn DeviceMapper,
47        layer_idx: usize,
48        loading_isq: bool,
49        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
50    ) -> Result<Self> {
51        let hidden_sz = cfg.hidden_size;
52        let intermediate_sz = cfg.intermediate_size;
53        let gate_proj = linear_b(
54            hidden_sz,
55            intermediate_sz,
56            false,
57            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
58            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
59            lora_config,
60            count,
61            ord,
62            preload_adapters,
63        )?;
64        let up_proj = linear_b(
65            hidden_sz,
66            intermediate_sz,
67            false,
68            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
69            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
70            lora_config,
71            count,
72            ord,
73            preload_adapters,
74        )?;
75        let down_proj = linear_b(
76            intermediate_sz,
77            hidden_sz,
78            false,
79            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
80            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
81            lora_config,
82            count,
83            ord,
84            preload_adapters,
85        )?;
86        Ok(Self {
87            gate_proj,
88            up_proj,
89            down_proj,
90            act_fn: cfg.hidden_act()?,
91        })
92    }
93
94    fn forward(
95        &self,
96        xs: &Tensor,
97        scalings: Option<Tensor>,
98        global_scaling_weight: f64,
99        is_scaling_pass: Option<f64>,
100    ) -> Result<Tensor> {
101        let original_dtype = xs.dtype();
102        let mut xs = xs.clone();
103        if let Some(t) = self.gate_proj.quantized_act_type() {
104            xs = xs.to_dtype(t)?;
105        }
106        let lhs = self
107            .gate_proj
108            .lora_forward(
109                &xs,
110                scalings.clone(),
111                global_scaling_weight,
112                is_scaling_pass,
113            )?
114            .apply(&self.act_fn)?;
115        let rhs = self.up_proj.lora_forward(
116            &xs,
117            scalings.clone(),
118            global_scaling_weight,
119            is_scaling_pass,
120        )?;
121        let mut res = self.down_proj.lora_forward(
122            &(lhs * rhs)?,
123            scalings,
124            global_scaling_weight,
125            is_scaling_pass,
126        )?;
127        if self.gate_proj.quantized_act_type().is_some() {
128            res = res.to_dtype(original_dtype)?;
129        }
130        Ok(res)
131    }
132}
133
134struct Attention {
135    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
136    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
137    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
138    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
139    num_heads: usize,
140    num_kv_heads: usize,
141    head_dim: usize,
142    rotary_emb: Arc<RotaryEmbedding>,
143    use_sliding_window: bool,
144    sliding_window: Option<usize>,
145    sdpa_params: SdpaParams,
146}
147
148impl Attention {
149    #[allow(clippy::too_many_arguments)]
150    fn new(
151        rotary_emb: Arc<RotaryEmbedding>,
152        cfg: &Config,
153        vb: ShardedVarBuilder,
154        lora_config: &[((String, String), LoraConfig)],
155        count: &mut usize,
156        ord: &Ordering,
157        mapper: &dyn DeviceMapper,
158        layer_idx: usize,
159        loading_isq: bool,
160        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
161    ) -> Result<Self> {
162        let hidden_sz = cfg.hidden_size;
163        let num_heads = cfg.num_attention_heads;
164        let num_kv_heads = cfg.num_key_value_heads;
165        let head_dim = cfg.head_dim;
166        let bias = cfg.attention_bias;
167        let q_proj = linear_b(
168            hidden_sz,
169            num_heads * head_dim,
170            bias,
171            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
172            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
173            lora_config,
174            count,
175            ord,
176            preload_adapters,
177        )?;
178        let k_proj = linear_b(
179            hidden_sz,
180            num_kv_heads * head_dim,
181            bias,
182            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
183            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
184            lora_config,
185            count,
186            ord,
187            preload_adapters,
188        )?;
189        let v_proj = linear_b(
190            hidden_sz,
191            num_kv_heads * head_dim,
192            bias,
193            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
194            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
195            lora_config,
196            count,
197            ord,
198            preload_adapters,
199        )?;
200        let o_proj = linear_b(
201            num_heads * head_dim,
202            hidden_sz,
203            bias,
204            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
205            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
206            lora_config,
207            count,
208            ord,
209            preload_adapters,
210        )?;
211        let sliding_window = if layer_idx % 2 == 0 {
212            // ^ Order is SWA, global, SWA
213            Some(cfg.sliding_window)
214        } else {
215            None
216        };
217        Ok(Self {
218            q_proj,
219            k_proj,
220            v_proj,
221            o_proj,
222            num_heads,
223            num_kv_heads,
224            head_dim,
225            rotary_emb,
226            use_sliding_window: layer_idx % 2 == 0, // Order is SWA, global, SWA
227            sliding_window,
228            sdpa_params: SdpaParams {
229                n_kv_groups: num_heads / num_kv_heads,
230                use_flash_attn: cfg.use_flash_attn,
231                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
232                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
233                sliding_window,
234            },
235        })
236    }
237
238    #[allow(clippy::too_many_arguments)]
239    fn forward(
240        &self,
241        xs: &Tensor,
242        attention_mask: Option<&Tensor>,
243        sliding_attention_mask: Option<&Tensor>,
244        seqlen_offsets: &[usize],
245        kv_cache: &mut Option<(Tensor, Tensor)>,
246        scalings: Option<Tensor>,
247        global_scaling_weight: f64,
248        is_scaling_pass: Option<f64>,
249        flash_params: &FlashParams,
250    ) -> Result<Tensor> {
251        let (b_sz, q_len, _) = xs.dims3()?;
252
253        let original_dtype = xs.dtype();
254        let mut xs = xs.clone();
255        if let Some(t) = self.q_proj.quantized_act_type() {
256            xs = xs.to_dtype(t)?;
257        }
258        let mut q = self.q_proj.lora_forward(
259            &xs,
260            scalings.clone(),
261            global_scaling_weight,
262            is_scaling_pass,
263        )?;
264        let mut k = self.k_proj.lora_forward(
265            &xs,
266            scalings.clone(),
267            global_scaling_weight,
268            is_scaling_pass,
269        )?;
270        let mut v = self.v_proj.lora_forward(
271            &xs,
272            scalings.clone(),
273            global_scaling_weight,
274            is_scaling_pass,
275        )?;
276        if self.q_proj.quantized_act_type().is_some() {
277            q = q.to_dtype(original_dtype)?;
278            k = k.to_dtype(original_dtype)?;
279            v = v.to_dtype(original_dtype)?;
280        }
281
282        let (q, k, v) = if q_len != 1 {
283            let q = q
284                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
285                .transpose(1, 2)?;
286            let k = k
287                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
288                .transpose(1, 2)?;
289            let v = v
290                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
291                .transpose(1, 2)?;
292            (q, k, v)
293        } else {
294            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
295            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
296            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
297            (q, k, v)
298        };
299
300        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
301
302        let mask = if self.use_sliding_window {
303            sliding_attention_mask
304        } else {
305            attention_mask
306        };
307
308        // self.sliding_window is None if !self.use_sliding_window
309        let (k, v, mask) = Cache::update_kv_cache_sliding_window(
310            kv_cache,
311            k,
312            v,
313            mask,
314            self.sliding_window,
315            false,
316        )?;
317
318        let mut attn_output = Sdpa.run_attention(
319            &q,
320            &k,
321            &v,
322            mask.as_ref(),
323            Some(flash_params),
324            &self.sdpa_params,
325        )?;
326
327        if let Some(t) = self.q_proj.quantized_act_type() {
328            attn_output = attn_output.to_dtype(t)?;
329        }
330        let mut res = self.o_proj.lora_forward(
331            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
332            scalings.clone(),
333            global_scaling_weight,
334            is_scaling_pass,
335        )?;
336        if self.q_proj.quantized_act_type().is_some() {
337            res = res.to_dtype(original_dtype)?;
338        }
339        Ok(res)
340    }
341}
342
343struct DecoderLayer {
344    self_attn: Attention,
345    mlp: MLP,
346    input_layernorm: RmsNorm,
347    post_attention_layernorm: RmsNorm,
348    pre_feedforward_layernorm: RmsNorm,
349    post_feedforward_layernorm: RmsNorm,
350}
351
352impl DecoderLayer {
353    #[allow(clippy::too_many_arguments)]
354    fn new(
355        rotary_emb: Arc<RotaryEmbedding>,
356        cfg: &Config,
357        vb: ShardedVarBuilder,
358        lora_config: &[((String, String), LoraConfig)],
359        count: &mut usize,
360        ord: &Ordering,
361        mapper: &dyn DeviceMapper,
362        layer_idx: usize,
363        loading_isq: bool,
364        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
365    ) -> Result<Self> {
366        let self_attn = Attention::new(
367            rotary_emb,
368            cfg,
369            vb.pp("self_attn"),
370            lora_config,
371            count,
372            ord,
373            mapper,
374            layer_idx,
375            loading_isq,
376            preload_adapters,
377        )?;
378        let mlp = MLP::new(
379            cfg,
380            vb.pp("mlp"),
381            lora_config,
382            count,
383            ord,
384            mapper,
385            layer_idx,
386            loading_isq,
387            preload_adapters,
388        )?;
389        let input_layernorm = RmsNorm::new_gemma(
390            cfg.hidden_size,
391            cfg.rms_norm_eps,
392            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
393        )?;
394        let post_attention_layernorm = RmsNorm::new_gemma(
395            cfg.hidden_size,
396            cfg.rms_norm_eps,
397            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
398        )?;
399        let pre_feedforward_layernorm = RmsNorm::new_gemma(
400            cfg.hidden_size,
401            cfg.rms_norm_eps,
402            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
403        )?;
404        let post_feedforward_layernorm = RmsNorm::new_gemma(
405            cfg.hidden_size,
406            cfg.rms_norm_eps,
407            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
408        )?;
409        Ok(Self {
410            self_attn,
411            mlp,
412            input_layernorm,
413            post_attention_layernorm,
414            pre_feedforward_layernorm,
415            post_feedforward_layernorm,
416        })
417    }
418
419    #[allow(clippy::too_many_arguments)]
420    fn forward(
421        &self,
422        xs: &Tensor,
423        attention_mask: Option<&Tensor>,
424        sliding_attention_mask: Option<&Tensor>,
425        seqlen_offsets: &[usize],
426        kv_cache: &mut Option<(Tensor, Tensor)>,
427        scalings: Option<Tensor>,
428        global_scaling_weight: f64,
429        is_scaling_pass: Option<f64>,
430        flash_params: &FlashParams,
431    ) -> Result<Tensor> {
432        let residual = xs;
433        let xs = self.input_layernorm.forward(xs)?;
434        let xs = self
435            .self_attn
436            .forward(
437                &xs,
438                attention_mask,
439                sliding_attention_mask,
440                seqlen_offsets,
441                kv_cache,
442                scalings.clone(),
443                global_scaling_weight,
444                is_scaling_pass,
445                flash_params,
446            )?
447            .apply(&self.post_attention_layernorm)?;
448        let xs = (xs + residual)?;
449        let residual = &xs;
450        let xs = self
451            .mlp
452            .forward(
453                &xs.apply(&self.pre_feedforward_layernorm)?,
454                scalings,
455                global_scaling_weight,
456                is_scaling_pass,
457            )?
458            .apply(&self.post_feedforward_layernorm)?;
459        residual + xs
460    }
461}
462
463pub struct Model {
464    embed_tokens: candle_nn::Embedding,
465    layers: Vec<DecoderLayer>,
466    norm: RmsNorm,
467    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
468    hidden_size: usize,
469    device: Device,
470    cache: EitherCache,
471    max_seq_len: usize,
472    mapper: Box<dyn DeviceMapper + Send + Sync>,
473    sliding_window: usize,
474    final_logit_softcapping: Option<f64>,
475    xlora_classifier: Option<XLoraClassifier>,
476    dtype: DType,
477    cfg: ModelConfigMetadata,
478}
479
480impl Model {
481    #[allow(clippy::too_many_arguments)]
482    pub fn new(
483        cfg: &Config,
484        vb: ShardedVarBuilder,
485        lora_config: &[((String, String), LoraConfig)],
486        xlora_config: Option<XLoraConfig>,
487        xlora_ordering: Ordering,
488        is_gptx: bool,
489        normal_loading_metadata: NormalLoadingMetadata,
490        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
491    ) -> Result<Self> {
492        if let Some(ref quant_cfg) = &cfg.quantization_config {
493            tracing::info!(
494                "Using {} quantization: {}.",
495                quant_cfg.name(),
496                quant_cfg.get_bits_name(&vb)
497            );
498        }
499        let mapper = normal_loading_metadata.mapper;
500
501        let vb_m = vb.pp("model");
502        let embed_tokens = layers::embedding(
503            cfg.vocab_size,
504            cfg.hidden_size,
505            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
506            &cfg.quantization_config,
507        )?;
508        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
509        let vb_l = vb_m.pp("layers");
510        let mut ropes = HashMap::new();
511        for layer_idx in 0..cfg.num_hidden_layers {
512            let device = mapper
513                .device_for(layer_idx, false)
514                .unwrap_or(&normal_loading_metadata.real_device);
515            ropes.insert(
516                device.location(),
517                Arc::new(RotaryEmbedding::new(
518                    cfg.rope_theta as f32,
519                    cfg.head_dim,
520                    cfg.max_position_embeddings,
521                    device,
522                    is_gptx,
523                    vb.dtype(),
524                )?),
525            );
526        }
527        let mut count = 0;
528        for layer_idx in NiceProgressBar::<_, 'b'>(
529            0..cfg.num_hidden_layers,
530            "Loading repeating layers",
531            &normal_loading_metadata.multi_progress,
532        ) {
533            let device = mapper
534                .device_for(layer_idx, false)
535                .unwrap_or(&normal_loading_metadata.real_device);
536            let rotary_emb = ropes
537                .get(&device.location())
538                .expect("No RoPE for device location!")
539                .clone();
540            let layer = DecoderLayer::new(
541                rotary_emb.clone(),
542                cfg,
543                vb_l.pp(layer_idx),
544                lora_config,
545                &mut count,
546                &xlora_ordering,
547                &*mapper,
548                layer_idx,
549                normal_loading_metadata.loading_isq,
550                preload_adapters,
551            )?;
552            layers.push(layer)
553        }
554        if xlora_config.is_none() && preload_adapters.is_none() {
555            // We are now a LoRA model so we must merge the weights
556            info!("Merging LoRA adapters.");
557            for layer in layers.iter_mut().tqdm() {
558                Arc::get_mut(&mut layer.self_attn.k_proj)
559                    .unwrap()
560                    .merge_weights()?;
561                Arc::get_mut(&mut layer.self_attn.o_proj)
562                    .unwrap()
563                    .merge_weights()?;
564                Arc::get_mut(&mut layer.self_attn.q_proj)
565                    .unwrap()
566                    .merge_weights()?;
567                Arc::get_mut(&mut layer.self_attn.v_proj)
568                    .unwrap()
569                    .merge_weights()?;
570
571                Arc::get_mut(&mut layer.mlp.down_proj)
572                    .unwrap()
573                    .merge_weights()?;
574                Arc::get_mut(&mut layer.mlp.gate_proj)
575                    .unwrap()
576                    .merge_weights()?;
577                Arc::get_mut(&mut layer.mlp.up_proj)
578                    .unwrap()
579                    .merge_weights()?;
580            }
581        }
582        let norm = RmsNorm::new_gemma(
583            cfg.hidden_size,
584            cfg.rms_norm_eps,
585            mapper.set_nm_device(vb_m.pp("norm"), false),
586        )?;
587
588        let lm_head = linear_no_bias(
589            embed_tokens.embeddings().dim(1)?,
590            embed_tokens.embeddings().dim(0)?,
591            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
592            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
593            lora_config,
594            &mut count,
595            &xlora_ordering,
596            preload_adapters,
597        )?;
598        if xlora_config.is_some() && lm_head.is_lora() {
599            // This is why we can pass dummy values (..., None, 1.0, None)?
600            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
601        }
602
603        Ok(Self {
604            embed_tokens,
605            layers,
606            norm,
607            lm_head,
608            device: normal_loading_metadata.real_device,
609            hidden_size: cfg.hidden_size,
610            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
611            max_seq_len: cfg.max_position_embeddings,
612            mapper,
613            sliding_window: cfg.sliding_window,
614            final_logit_softcapping: cfg.final_logit_softcapping,
615            dtype: vb.dtype(),
616            xlora_classifier: xlora_config.map(|xlora_config| {
617                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
618            }),
619            cfg: ModelConfigMetadata {
620                max_seq_len: cfg.max_position_embeddings,
621                num_layers: cfg.num_hidden_layers,
622                hidden_size: cfg.hidden_size,
623                num_kv_heads: cfg.num_key_value_heads,
624                num_attn_heads: cfg.num_attention_heads,
625                sliding_window: None,
626                k_head_dim: cfg.head_dim,
627                v_head_dim: cfg.head_dim,
628            },
629        })
630    }
631
632    #[allow(clippy::too_many_arguments)]
633    fn inner_forward(
634        &self,
635        input_ids: &Tensor,
636        seqlen_offsets: &[usize],
637        scalings: Option<Tensor>,
638        is_full_pass: bool,
639        no_kv_cache: bool,
640        is_scaling_pass: Option<f64>,
641        flash_params: &FlashParams,
642    ) -> Result<Tensor> {
643        let xs = self.embed_tokens.forward(input_ids)?;
644        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
645        let mut cache = if is_full_pass {
646            if no_kv_cache {
647                let mut new_cache = Vec::new();
648                for _ in 0..self.cache.full().xlora_lock().len() {
649                    new_cache.push(None);
650                }
651
652                self.cache.full().xlora_lock().clone_from(&new_cache);
653            }
654            self.cache.full().xlora_lock()
655        } else {
656            self.cache.full().lock()
657        };
658        let attention_mask = CausalMasker.make_causal_mask_matrix(
659            input_ids,
660            &*cache,
661            xs.dtype(),
662            self.cfg.num_attn_heads,
663        )?;
664        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
665            input_ids,
666            &*cache,
667            Some(self.sliding_window),
668            xs.dtype(),
669            self.cfg.num_attn_heads,
670        )?;
671        for (i, layer) in self.layers.iter().enumerate() {
672            xs = self.mapper.map(xs, i)?;
673            xs = layer.forward(
674                &xs,
675                attention_mask
676                    .as_ref()
677                    .map(|m| m.to_device(xs.device()).unwrap())
678                    .as_ref(),
679                sliding_attention_mask
680                    .as_ref()
681                    .map(|m| m.to_device(xs.device()).unwrap())
682                    .as_ref(),
683                seqlen_offsets,
684                &mut cache[i],
685                scalings.clone(),
686                self.xlora_classifier
687                    .as_ref()
688                    .map(|classifier| classifier.get_global_scaling_weight())
689                    .unwrap_or(1.0),
690                is_scaling_pass,
691                flash_params,
692            )?;
693        }
694        let xs = xs.to_device(&self.device)?;
695        let mut xs = xs.apply(&self.norm)?;
696        if let Some(t) = self.lm_head.quantized_act_type() {
697            xs = xs.to_dtype(t)?;
698        }
699
700        let mut xs = self.lm_head.lora_forward(&xs, None, 1.0, None)?;
701
702        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
703            xs = (xs / final_logit_softcapping)?;
704            xs = xs.tanh()?;
705            xs = (xs * final_logit_softcapping)?;
706        }
707
708        Ok(xs)
709    }
710
711    #[allow(clippy::too_many_arguments)]
712    pub fn forward(
713        &self,
714        input_ids: &Tensor,
715        input_ids_full: &Tensor,
716        seqlen_offsets: &[usize],
717        seqlen_offsets_full: &[usize],
718        no_kv_cache: bool,
719        non_granular_state: &Option<NonGranularState>,
720        context_lens: Vec<(usize, usize)>,
721        flash_params: &FlashParams,
722        flash_params_full: &FlashParams,
723    ) -> Result<Tensor> {
724        if self.xlora_classifier.is_some() {
725            let scalings = self.get_scalings(
726                input_ids,
727                input_ids_full,
728                seqlen_offsets,
729                seqlen_offsets_full,
730                no_kv_cache,
731                non_granular_state,
732                &vec![usize::MAX; context_lens.len()],
733                flash_params,
734                flash_params_full,
735            )?;
736
737            if no_kv_cache {
738                let mut res = self
739                    .inner_forward(
740                        input_ids_full,
741                        seqlen_offsets_full,
742                        Some(scalings),
743                        true,
744                        no_kv_cache,
745                        None,
746                        flash_params_full,
747                    )?
748                    .contiguous()?;
749                if let Some(t) = self.lm_head.quantized_act_type() {
750                    res = res.to_dtype(t)?;
751                }
752                extract_logits(
753                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
754                    context_lens,
755                )
756            } else {
757                // is_full_pass=true is ok because no_kv_cache=false
758                let mut res = self
759                    .inner_forward(
760                        input_ids,
761                        seqlen_offsets,
762                        Some(scalings),
763                        true,
764                        no_kv_cache,
765                        None,
766                        flash_params,
767                    )?
768                    .contiguous()?;
769                if let Some(t) = self.lm_head.quantized_act_type() {
770                    res = res.to_dtype(t)?;
771                }
772                extract_logits(
773                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
774                    context_lens,
775                )
776            }
777        } else {
778            let mut res = self
779                .inner_forward(
780                    input_ids,
781                    seqlen_offsets,
782                    None,
783                    false,
784                    no_kv_cache,
785                    None,
786                    flash_params,
787                )?
788                .contiguous()?;
789            if let Some(t) = self.lm_head.quantized_act_type() {
790                res = res.to_dtype(t)?;
791            }
792            extract_logits(
793                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
794                context_lens,
795            )
796        }
797    }
798}
799
800impl IsqModel for Model {
801    fn get_layers(
802        &mut self,
803    ) -> (
804        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
805        &dyn DeviceMapper,
806    ) {
807        let mut tensors = Vec::new();
808        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
809        for (i, layer) in self.layers.iter_mut().enumerate() {
810            tensors.push((
811                Arc::get_mut(&mut layer.self_attn.q_proj)
812                    .unwrap()
813                    .quant_inner(),
814                Some(i),
815            ));
816            tensors.push((
817                Arc::get_mut(&mut layer.self_attn.k_proj)
818                    .unwrap()
819                    .quant_inner(),
820                Some(i),
821            ));
822            tensors.push((
823                Arc::get_mut(&mut layer.self_attn.v_proj)
824                    .unwrap()
825                    .quant_inner(),
826                Some(i),
827            ));
828            tensors.push((
829                Arc::get_mut(&mut layer.self_attn.o_proj)
830                    .unwrap()
831                    .quant_inner(),
832                Some(i),
833            ));
834            tensors.push((
835                Arc::get_mut(&mut layer.mlp.gate_proj)
836                    .unwrap()
837                    .quant_inner(),
838                Some(i),
839            ));
840            tensors.push((
841                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
842                Some(i),
843            ));
844            tensors.push((
845                Arc::get_mut(&mut layer.mlp.down_proj)
846                    .unwrap()
847                    .quant_inner(),
848                Some(i),
849            ));
850        }
851        (tensors, &*self.mapper)
852    }
853
854    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
855        panic!("Cannot generate UQFF for an adapter model.")
856    }
857}
858
859impl NormalModel for Model {
860    fn forward(
861        &self,
862        _input_ids: &Tensor,
863        _seqlen_offsets: &[usize],
864        _context_lens: Vec<(usize, usize)>,
865        _position_ids: Vec<usize>,
866        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
867        _flash_params: &FlashParams,
868    ) -> Result<Tensor> {
869        unreachable!()
870    }
871    fn xlora_forward(
872        &self,
873        input_ids: &Tensor,
874        input_ids_full: &Tensor,
875        seqlen_offsets: &[usize],
876        seqlen_offsets_full: &[usize],
877        no_kv_cache: bool,
878        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
879        context_lens: Vec<(usize, usize)>,
880        _position_ids: Vec<usize>,
881        flash_params: &FlashParams,
882        flash_params_full: &FlashParams,
883    ) -> Result<Tensor> {
884        self.forward(
885            input_ids,
886            input_ids_full,
887            seqlen_offsets,
888            seqlen_offsets_full,
889            no_kv_cache,
890            non_granular_state,
891            context_lens,
892            flash_params,
893            flash_params_full,
894        )
895    }
896    fn cache(&self) -> &EitherCache {
897        &self.cache
898    }
899    fn cache_mut(&mut self) -> &mut EitherCache {
900        &mut self.cache
901    }
902    fn device(&self) -> &Device {
903        &self.device
904    }
905    fn is_xlora(&self) -> bool {
906        false
907    }
908    fn max_seq_len(&self) -> usize {
909        self.max_seq_len
910    }
911    fn config(&self) -> &ModelConfigMetadata {
912        &self.cfg
913    }
914}
915
916impl ScalingsMaker for Model {
917    fn dtype(&self) -> DType {
918        self.dtype
919    }
920    fn get_cache(&self) -> &EitherCache {
921        &self.cache
922    }
923    fn get_classifier(&self) -> &XLoraClassifier {
924        self.xlora_classifier.as_ref().unwrap()
925    }
926    fn forward(
927        &self,
928        input_ids: &Tensor,
929        seqlen_offsets: &[usize],
930        scalings: Tensor,
931        is_full_pass: bool,
932        no_kv_cache: bool,
933        is_scaling_pass: Option<f64>,
934        _context_lens: &[usize],
935        flash_params: &FlashParams,
936    ) -> Result<Tensor> {
937        self.inner_forward(
938            input_ids,
939            seqlen_offsets,
940            Some(scalings),
941            is_full_pass,
942            no_kv_cache,
943            is_scaling_pass,
944            flash_params,
945        )
946    }
947}
948
949impl AnyMoeBaseModelMixin for Model {}