mistralrs_core/vision_models/phi3/
mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3pub(crate) mod phi3_inputs_processor;
4
5// This implementation is based on:
6// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
7use candle_core::{
8    shape::ShapeWithOneHole, DType, Device, IndexOp, Module, Result, Shape, Tensor, D,
9};
10use either::Either;
11use mistralrs_quant::{
12    BitWiseOp, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder,
13};
14use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
15
16use crate::{
17    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
18    attention::SdpaParams,
19    device_map::DeviceMapper,
20    get_delta_from_lora_ab,
21    layers::{
22        self, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
23        PhiRotaryEmbedding, RmsNorm, Sdpa,
24    },
25    layers_masker::PastKvLenCache,
26    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
27    pipeline::{
28        extract_logits,
29        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
30        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
31    },
32    serde_default_fn,
33    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
34    vision_models::clip::{ClipConfig, ClipVisionTransformer},
35    AnyMoeConfig, AnyMoeExpertType,
36};
37
38use super::clip;
39
/// Configuration for the image-embedding projection layer, deserialized from
/// the model's config JSON (`embd_layer`). All fields are optional; defaults
/// are applied in `ImageEmbedding::new`.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct EmbedLayerConfig {
    /// Order of global vs sub-image features ("glb_sub" or "sub_glb"); defaults to "glb_sub".
    pub hd_transform_order: Option<String>,
    /// Projection class, "linear" or "mlp"; defaults to "linear".
    pub projection_cls: Option<String>,
    /// Whether the high-definition (multi-crop) transform is enabled; defaults to false.
    pub use_hd_transform: Option<bool>,
    /// Whether learnable separator tensors (`glb_GN`/`sub_GN`) are present; defaults to false.
    /// Must match `use_hd_transform` (asserted at load time).
    pub with_learnable_separator: Option<bool>,
}
47
/// Configuration for the vision backbone, deserialized from the model's
/// config JSON (`img_processor`). Only `"clip_vision_model"` is supported.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct ImageProcessorConfig {
    /// Feature dimension produced by the vision tower (CLIP hidden size).
    pub image_dim_out: usize,
    /// Vision backbone identifier; must be "clip_vision_model".
    pub name: String,
    /// Number of image tokens each image occupies in the text sequence.
    pub num_img_tokens: usize,
    /// Hidden-state layer to take features from (negative = from the end); defaults to -2.
    pub layer_idx: Option<isize>,
    /// "patch" (drop CLS token) or "cls_patch" (keep it); defaults to "patch".
    pub type_feature: Option<String>,
}
56
// Serde default for `Config::tie_word_embeddings` when absent from the JSON.
serde_default_fn!(bool, word_emb_default, false);
58
/// Phi-3-Vision model configuration, deserialized from `config.json`.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    /// Long-context RoPE scaling (e.g. for the 128k variants).
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    /// Pre-extension context length; used by the RoPE scaling schedule.
    pub original_max_position_embeddings: usize,
    /// Image-embedding projection settings (`embd_layer` in the JSON).
    pub embd_layer: EmbedLayerConfig,
    /// Vision backbone settings (`img_processor` in the JSON).
    pub img_processor: ImageProcessorConfig,
    // Accept both "quantization_config" and the shorter "quantization" key.
    #[serde(alias = "quantization")]
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}
83
84impl From<Config> for PhiRopeConfig {
85    fn from(val: Config) -> Self {
86        PhiRopeConfig {
87            rope_scaling: val.rope_scaling,
88            max_position_embeddings: val.max_position_embeddings,
89            original_max_position_embeddings: val.original_max_position_embeddings,
90            rope_theta: val.rope_theta,
91            head_dim: val.hidden_size / val.num_attention_heads,
92            partial_rotary_factor: None,
93        }
94    }
95}
96
97impl Config {
98    pub fn head_dim(&self) -> usize {
99        self.hidden_size / self.num_attention_heads
100    }
101}
102
/// A `Module` that can also report where (device) and how (dtype) its
/// parameters live, so inputs can be moved/cast before the forward pass.
trait ModuleWithMetadata: Module + Debug + Send + Sync {
    fn device(&self) -> Device;
    fn dtype(&self) -> DType;
}
107
/// Newtype adapter so an `Arc<dyn QuantMethod>` (a quantizable linear layer)
/// can be used wherever a `Module`/`ModuleWithMetadata` is expected.
#[derive(Debug)]
struct QuantMethodWrapper(Arc<dyn QuantMethod>);
110
impl Module for QuantMethodWrapper {
    /// Delegates directly to the wrapped quant method's forward pass.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}
116
impl ModuleWithMetadata for QuantMethodWrapper {
    // NOTE(review): both accessors `.unwrap()` on `unquant_weight_bias()`, so
    // they panic if the wrapped method cannot expose an unquantized weight —
    // presumably fine here since this file only wraps unquantized `linear_b`
    // layers; confirm if other quant methods ever get wrapped.
    fn device(&self) -> Device {
        self.0.unquant_weight_bias().unwrap().0.device().clone()
    }
    fn dtype(&self) -> DType {
        self.0.unquant_weight_bias().unwrap().0.dtype()
    }
}
125
impl ModuleWithMetadata for candle_nn::Activation {
    // Activations are stateless: they have no parameters, so asking for a
    // device/dtype is a logic error. Callers only query the projection layers
    // (see `ImageEmbedding::forward`, which inspects `self.layers.0[0]`).
    fn device(&self) -> Device {
        unreachable!()
    }
    fn dtype(&self) -> DType {
        unreachable!()
    }
}
134
/// A 6-dimensional shape whose last dimension is a "hole" (`()`) to be
/// inferred from the element count — candle's built-in `ShapeWithOneHole`
/// impls only cover smaller tuples.
#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));
137
138fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
139    if prod_d == 0 {
140        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
141    }
142    if el_count % prod_d != 0 {
143        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
144    }
145    Ok(el_count / prod_d)
146}
147
impl ShapeWithOneHole for BigShapeWithOneHole {
    /// Resolves the trailing hole dimension so the full 6-D shape accounts
    /// for exactly `el_count` elements.
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, d5, ()) = self.0;
        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
        Ok((d1, d2, d3, d4, d5, d).into())
    }
}
155
156// =================== BASE LAYERS ===================
157
/// Self-attention block for one decoder layer, with a fused QKV projection
/// (as in the upstream Phi-3 checkpoint layout).
struct Attention {
    /// Fused query/key/value projection; split by `narrow` in `forward`.
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    /// Present only when running with paged-attention KV cache management.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
168
impl Attention {
    /// Loads the attention weights (`qkv_proj`, `o_proj`) from `vb` and sets
    /// up SDPA parameters (GQA group count, 1/sqrt(head_dim) scaling, optional
    /// sliding window).
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        // Fused output width: Q rows plus K and V rows (K/V use num_kv_heads).
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        // No TP here.
        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// Runs self-attention on `xs` (shape `(b_sz, q_len, hidden)`).
    ///
    /// Uses paged attention when configured (with real or dummy metadata),
    /// otherwise a normal KV-cache + SDPA path. Returns the o_proj output
    /// with the same batch/sequence dims as the input.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // If the quant method needs a specific activation dtype, cast in for
        // the matmul and back out afterwards.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        // Split the fused projection: [Q | K | V] along the last dim.
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        // To (b, heads, q_len, head_dim). When q_len == 1 the transpose is a
        // no-op on the data layout, so reshape directly into the target shape.
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Normal path: append to the rolling KV cache, then SDPA over
                // the full cached keys/values.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // With a mask (prompt path) the output is (b, heads, q_len, head_dim)
        // and needs the transpose; the maskless decode path is already flat.
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
327
/// Gated feed-forward block with a fused gate+up projection, as stored in the
/// Phi-3 checkpoint (`gate_up_proj` is 2x the intermediate size).
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    /// Intermediate size; used to split the fused gate/up output in `forward`.
    i_size: usize,
    /// [hidden_size, intermediate_size] — reported via `MlpLayer::get_params`.
    params: Vec<usize>,
}
336
337impl Mlp {
338    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
339        let hidden_size = cfg.hidden_size;
340        let i_size = cfg.intermediate_size;
341
342        // No TP here.
343        let gate_up_proj = mistralrs_quant::linear_no_bias(
344            hidden_size,
345            2 * i_size,
346            &cfg.quantization_config,
347            vb.pp("gate_up_proj"),
348        )?;
349
350        let down_proj = mistralrs_quant::linear_no_bias(
351            i_size,
352            hidden_size,
353            &cfg.quantization_config,
354            vb.pp("down_proj"),
355        )?;
356
357        Ok(Self {
358            gate_up_proj,
359            down_proj,
360            act_fn: cfg.hidden_act,
361            i_size,
362            params: vec![hidden_size, i_size],
363        })
364    }
365}
366
// Marker impl: lets this MLP serve as a trainable expert in AnyMoE.
impl AnyMoeTrainableLayer for Mlp {}
368
impl MlpLayer for Mlp {
    /// Computes `down(up * act(gate))`, splitting the fused gate/up output at
    /// `i_size`. Casts activations through the quant method's dtype if needed.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
        // First half of the fused output is the gate, second half the up proj.
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    /// Layers eligible for in-situ quantization.
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    /// Returns [hidden_size, intermediate_size].
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    // gate_up, down
    /// Builds a copy with LoRA-style deltas added to the weights;
    /// `deltas[0]` applies to gate_up_proj, `deltas[1]` to down_proj.
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}
424
/// One pre-norm transformer decoder layer: RmsNorm -> attention -> residual,
/// then RmsNorm -> MLP -> residual.
struct DecoderLayer {
    self_attn: Attention,
    /// Boxed so AnyMoE can swap in an MoE MLP in place of the dense one.
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}
431
432impl DecoderLayer {
433    fn new(
434        rotary_emb: Arc<PhiRotaryEmbedding>,
435        cfg: &Config,
436        vb: ShardedVarBuilder,
437        mapper: &dyn DeviceMapper,
438        layer_idx: usize,
439        loading_isq: bool,
440        paged_attn: Option<PagedAttention>,
441    ) -> Result<Self> {
442        let self_attn = Attention::new(
443            rotary_emb,
444            cfg,
445            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
446            paged_attn,
447        )?;
448        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
449        let input_layernorm = RmsNorm::new(
450            cfg.hidden_size,
451            cfg.rms_norm_eps,
452            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
453        )?;
454        let post_attention_layernorm = RmsNorm::new(
455            cfg.hidden_size,
456            cfg.rms_norm_eps,
457            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
458        )?;
459        Ok(Self {
460            self_attn,
461            mlp: Box::new(mlp),
462            input_layernorm,
463            post_attention_layernorm,
464        })
465    }
466
467    #[allow(clippy::too_many_arguments)]
468    fn forward(
469        &self,
470        xs: &Tensor,
471        attention_mask: Option<&Tensor>,
472        seqlen_offsets: &[usize],
473        position_ids: &[usize],
474        kv_cache: &mut KvCache,
475        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
476        flash_params: &FlashParams,
477    ) -> Result<Tensor> {
478        let residual = xs;
479        let xs = self.input_layernorm.forward(xs)?;
480        let xs = self
481            .self_attn
482            .forward(
483                &xs,
484                attention_mask,
485                seqlen_offsets,
486                position_ids,
487                kv_cache,
488                metadata,
489                flash_params,
490            )
491            .unwrap();
492        let xs = (xs + residual)?;
493        let residual = &xs;
494        let xs = self
495            .mlp
496            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
497        residual + xs
498    }
499}
500
501// =================== ============= ===================
502
503// =================== VISION LAYERS ===================
504
// Image-token sentinel range: `ImageEmbedding::forward` treats any input id
// in (-MAX_INPUT_ID, 0) as an image-token slot to be replaced by features.
const MAX_INPUT_ID: f64 = 1e9;

/// Sequential container for the image->text projection layers.
#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);
509
510impl Module for EmbeddingLayers {
511    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
512        let mut xs = xs.clone();
513        for layer in &self.0 {
514            xs = layer.forward(&xs)?;
515        }
516        Ok(xs)
517    }
518}
519
/// Phi-3-V image embedding: runs pixels through a CLIP vision tower, applies
/// the optional high-definition (multi-crop) transform, and projects the
/// features into the text embedding space.
#[derive(Debug)]
pub struct ImageEmbedding {
    /// Text token embedding table (shared with the language model).
    wte: candle_nn::Embedding,
    image_dim_out: usize,
    num_img_tokens: usize,
    /// Learnable global separator (present iff the HD transform is enabled).
    glb_gn: Option<Tensor>,
    /// Learnable sub-image/row separator (present iff the HD transform is enabled).
    sub_gn: Option<Tensor>,
    /// Image->text projection (linear or MLP).
    layers: EmbeddingLayers,
    type_feature: String,
    /// Hidden-state index to extract from the vision tower (negative = from the end).
    layer_idx: isize,
    image_processor: ClipVisionTransformer,
    hd_transform_order: String,
    use_hd_transform: bool,
    vocab_size: usize,
    /// Projection weights/biases recorded at load time for re-serialization.
    tensors: Vec<(String, Tensor)>,
}
536
// Fixed CLIP vision-tower hyperparameters used by Phi-3-V (336px input,
// patch size 14); these are not read from the model config.
pub(crate) const PHI3V_CLIP_CONFIG: ClipConfig = ClipConfig {
    hidden_act: clip::Activation::QuickGelu,
    hidden_size: 1024,
    image_size: 336,
    intermediate_size: 4096,
    num_attention_heads: 16,
    num_channels: 3,
    num_hidden_layers: 24,
    patch_size: 14,
};
547
548impl ImageEmbedding {
549    fn new(
550        config: &Config,
551        wte: candle_nn::Embedding,
552        embed_config: &EmbedLayerConfig,
553        vb: ShardedVarBuilder,
554    ) -> Result<Self> {
555        let hidden_size = config.hidden_size;
556        if config.img_processor.name != "clip_vision_model" {
557            candle_core::bail!(
558                "img_processor=`{}` nor supported.",
559                config.img_processor.name
560            );
561        }
562        let image_dim_out = config.img_processor.image_dim_out;
563        let num_img_tokens = config.img_processor.num_img_tokens;
564
565        // CLIP image processor here...
566        let image_processor =
567            ClipVisionTransformer::new(vb.pp("img_processor.vision_model"), &PHI3V_CLIP_CONFIG)?;
568
569        // High dim transform
570        let use_hd_transform = embed_config.use_hd_transform.unwrap_or(false);
571        let with_learnable_separator = embed_config.with_learnable_separator.unwrap_or(false);
572        let hd_transform_order = embed_config
573            .hd_transform_order
574            .clone()
575            .unwrap_or("glb_sub".to_string());
576        assert_eq!(use_hd_transform, with_learnable_separator);
577        let (glb_gn, sub_gn) = if with_learnable_separator {
578            let glb_gn = vb.get((1, 1, image_dim_out * 4), "glb_GN")?;
579            let sub_gn = vb.get((1, 1, 1, image_dim_out * 4), "sub_GN")?;
580            (Some(glb_gn), Some(sub_gn))
581        } else {
582            (None, None)
583        };
584
585        // Inner projection
586        let projection_cls = embed_config
587            .projection_cls
588            .clone()
589            .unwrap_or("linear".to_string());
590
591        let mut tensors = Vec::new();
592        let layers: Vec<Box<dyn ModuleWithMetadata>> =
593            match (projection_cls.as_str(), use_hd_transform) {
594                ("linear", _) => {
595                    let a = mistralrs_quant::linear_b(
596                        image_dim_out,
597                        hidden_size,
598                        true,
599                        &None,
600                        vb.pp("img_projection"),
601                    )?;
602                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
603                    tensors.push(("img_projection.weight".to_string(), a_w));
604                    if let Some(b) = a_b {
605                        tensors.push(("img_projection.bias".to_string(), b));
606                    }
607                    vec![Box::new(QuantMethodWrapper(a))]
608                }
609                ("mlp", true) => {
610                    let dim_proj = hidden_size;
611                    let a = mistralrs_quant::linear_b(
612                        image_dim_out * 4,
613                        dim_proj,
614                        true,
615                        &None,
616                        vb.pp("img_projection.0"),
617                    )?;
618                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
619                    tensors.push(("img_projection.0.weight".to_string(), a_w));
620                    if let Some(b) = a_b {
621                        tensors.push(("img_projection.0.bias".to_string(), b));
622                    }
623                    let b = mistralrs_quant::linear_b(
624                        dim_proj,
625                        dim_proj,
626                        true,
627                        &None,
628                        vb.pp("img_projection.2"),
629                    )?;
630                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
631                    tensors.push(("img_projection.2.weight".to_string(), b_w));
632                    if let Some(b) = b_b {
633                        tensors.push(("img_projection.2.bias".to_string(), b));
634                    }
635                    vec![
636                        Box::new(QuantMethodWrapper(a)),
637                        Box::new(candle_nn::Activation::Gelu),
638                        Box::new(QuantMethodWrapper(b)),
639                    ]
640                }
641                ("mlp", false) => {
642                    let dim_proj = hidden_size;
643                    let a = mistralrs_quant::linear_b(
644                        image_dim_out,
645                        dim_proj,
646                        true,
647                        &None,
648                        vb.pp("img_projection.0"),
649                    )?;
650                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
651                    tensors.push(("img_projection.0.weight".to_string(), a_w));
652                    if let Some(b) = a_b {
653                        tensors.push(("img_projection.0.bias".to_string(), b));
654                    }
655                    let b = mistralrs_quant::linear_b(
656                        dim_proj,
657                        dim_proj,
658                        true,
659                        &None,
660                        vb.pp("img_projection.2"),
661                    )?;
662                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
663                    tensors.push(("img_projection.2.weight".to_string(), b_w));
664                    if let Some(b) = b_b {
665                        tensors.push(("img_projection.2.bias".to_string(), b));
666                    }
667                    vec![
668                        Box::new(QuantMethodWrapper(a)),
669                        Box::new(candle_nn::Activation::Gelu),
670                        Box::new(QuantMethodWrapper(b)),
671                    ]
672                }
673                _ => {
674                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
675                }
676            };
677
678        let layer_idx = config.img_processor.layer_idx.unwrap_or(-2);
679        let type_feature = config
680            .img_processor
681            .type_feature
682            .clone()
683            .unwrap_or("patch".to_string());
684
685        Ok(Self {
686            wte,
687            image_dim_out,
688            num_img_tokens,
689            glb_gn,
690            sub_gn,
691            layer_idx,
692            type_feature,
693            image_processor,
694            layers: EmbeddingLayers(layers),
695            hd_transform_order,
696            use_hd_transform,
697            vocab_size: config.vocab_size,
698            tensors,
699        })
700    }
701
702    fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
703        let hidden_states = self
704            .image_processor
705            .forward_get_hidden_states(&pixel_values.to_dtype(self.wte.embeddings().dtype())?)?;
706        let img_feature =
707            hidden_states[(hidden_states.len() as isize + self.layer_idx) as usize].clone();
708        if self.type_feature == "patch" {
709            img_feature.i((.., 1..))
710        } else if self.type_feature == "cls_patch" {
711            Ok(img_feature)
712        } else {
713            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
714        }
715    }
716
717    #[allow(non_snake_case)]
718    fn forward(
719        &self,
720        input_ids: &Tensor,
721        pixel_values: &Tensor,
722        image_sizes: Option<Vec<(usize, usize)>>,
723    ) -> Result<Tensor> {
724        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
725
726        let input_ids_lt = input_ids.lt(0.0f64)?;
727        let input_ids_gt = input_ids.gt(-MAX_INPUT_ID)?;
728        // positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
729        let positions = input_ids_lt.bitwise_and(&input_ids_gt)?.nonzero()?;
730        let target_dev = self.layers.0[0].device();
731        let target_dtype = self.layers.0[0].dtype();
732
733        let mut select = false;
734        // If some, use hd transform case and it contains num_img_toks
735        let mut hd_transform = None;
736        let mut image_set_tensor = None;
737        if positions.dim(0)? > 0 {
738            select = true;
739            // input_ids[positions[:, 0], positions[:, 1]]
740            if self.use_hd_transform {
741                if let Some(image_sizes_ref) = image_sizes.as_ref() {
742                    assert_eq!(pixel_values.dims().len(), 5);
743                    let bs = pixel_values.dim(0)?;
744                    let img_features = self.get_image_features(&pixel_values.flatten(0, 1)?)?;
745                    let base_feat_dim = (img_features.dims()[1] as f32).sqrt() as usize;
746                    assert_eq!(base_feat_dim, 24);
747
748                    // bs x max_num_crops x (24x24) x C
749                    let img_features =
750                        img_features.reshape((bs, (), base_feat_dim.pow(2), self.image_dim_out))?;
751                    let C = self.image_dim_out;
752                    let H = base_feat_dim;
753
754                    let mut output_imgs = Vec::new();
755                    let mut output_len = Vec::new();
756                    for (bs_, &(h, w)) in image_sizes_ref.iter().enumerate().take(bs) {
757                        let h = h / 336;
758                        let w = w / 336;
759                        let B_ = h * w;
760
761                        // 1 x (24x24) x 1024
762                        let global_img_feature = img_features.i((bs_, ..1))?;
763
764                        // 1 x 12 x 12 x 4096
765                        let glb_img = global_img_feature
766                            .reshape((1, H, H, C))?
767                            .reshape((1, H / 2, 2, H / 2, 2, C))?
768                            .contiguous()?
769                            .permute((0, 1, 3, 2, 4, 5))?
770                            .reshape((1, H / 2, H / 2, 4 * C))?
771                            .contiguous()?;
772                        let temp_glbl_gn = self
773                            .sub_gn
774                            .as_ref()
775                            .expect("Need `sub_gn` if `use_hd_transform`")
776                            .repeat((1, H / 2, 1, 1))?;
777
778                        // 1 x 156 x 4096
779                        let glb_img =
780                            Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((1, (), 4 * C))?;
781
782                        // (max_num_crops-1) x (12x12) x C
783                        let sub_img = img_features.i((bs_, 1..))?;
784
785                        // 16x574x1024
786                        // Get rid of padding sub_img
787                        let sub_img = sub_img.i(..B_)?;
788
789                        // (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
790                        let sub_img = sub_img
791                            .reshape((B_, H, H, C))?
792                            .reshape((B_, H / 2, 2, H / 2, 2, C))?
793                            .contiguous()?
794                            .permute((0, 1, 3, 2, 4, 5))?
795                            .reshape((B_, (), 4 * C))?
796                            .contiguous()?;
797                        let sub_img = sub_img
798                            .reshape(BigShapeWithOneHole((1usize, h, w, 12usize, 12usize, ())))?
799                            .permute((0, 1, 3, 2, 4, 5))?
800                            .reshape((1, h * 12, w * 12, 4 * C))?;
801                        let temp_sub_gn = self
802                            .sub_gn
803                            .as_ref()
804                            .expect("Need `sub_gn` if `use_hd_transform`")
805                            .repeat((1, h * 12, 1, 1))?;
806
807                        let sub_img =
808                            Tensor::cat(&[sub_img, temp_sub_gn], 2)?.reshape((1, (), 4 * C))?;
809
810                        // (1, num_img_tokens, 1024*4)
811
812                        match self.hd_transform_order.as_str() {
813                            "glb_sub" => {
814                                output_imgs.push(Tensor::cat(
815                                    &[
816                                        glb_img,
817                                        self.glb_gn
818                                            .as_ref()
819                                            .expect("Need `glb_gn` if `use_hd_transform`")
820                                            .clone(),
821                                        sub_img,
822                                    ],
823                                    1,
824                                )?);
825                            }
826                            "sub_glb" => {
827                                output_imgs.push(Tensor::cat(
828                                    &[
829                                        sub_img,
830                                        self.glb_gn
831                                            .as_ref()
832                                            .expect("Need `glb_gn` if `use_hd_transform`")
833                                            .clone(),
834                                        glb_img,
835                                    ],
836                                    1,
837                                )?);
838                            }
839                            other => {
840                                candle_core::bail!("Invalid hd_transform_order=`{other}`");
841                            }
842                        }
843
844                        let temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12;
845                        assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
846                        output_len.push(temp_len);
847                    }
848
849                    hd_transform = Some(output_len);
850                    let mut image_set_tensor_inner = Vec::new();
851                    for img in output_imgs {
852                        let layerout = self
853                            .layers
854                            .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
855                        image_set_tensor_inner.push(layerout);
856                    }
857                    image_set_tensor = Some(Either::Left(image_set_tensor_inner));
858                }
859            } else if pixel_values.dims().len() == 4 {
860                let tt = self
861                    .get_image_features(pixel_values)?
862                    .to_device(&target_dev)?
863                    .to_dtype(target_dtype)?
864                    .reshape(((), self.image_dim_out))?;
865                let image_set_tensor_inner = self.layers.forward(&tt)?;
866                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
867            } else if pixel_values.dims().len() == 3 {
868                let tt = pixel_values
869                    .to_device(&target_dev)?
870                    .to_dtype(target_dtype)?
871                    .reshape(((), self.image_dim_out))?;
872                let image_set_tensor_inner = self.layers.forward(&tt)?;
873                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
874            } else {
875                unreachable!()
876            }
877        }
878
879        let input_ids = input_ids.clamp(0.0, self.vocab_size as f64)?;
880        let mut hidden_states = self.wte.forward(&input_ids)?;
881        if select {
882            match (hd_transform, image_set_tensor) {
883                (Some(output_lens), Some(Either::Left(image_set_tensors))) => {
884                    let mut idx = 0;
885                    for (i, cnt) in output_lens.into_iter().enumerate() {
886                        let img_set_tensor = image_set_tensors[i]
887                            .to_device(&target_dev)?
888                            .to_dtype(target_dtype)?;
889                        // hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
890                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
891                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
892                        hidden_states = hidden_states.slice_assign(
893                            &[p_0..p_0 + 1, p_1..p_1 + cnt, 0..img_set_tensor.dims()[2]],
894                            &img_set_tensor,
895                        )?;
896                        idx += cnt;
897                    }
898                }
899                (None, Some(Either::Right(image_set_tensor))) => {
900                    let mut idx = 0;
901                    // Know len(img_embeds) == pixel_values.dim(0) == len(selected_g_values)
902                    // https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/dbcdaaacf52c8e40cf8de6d6ffa6ff6860e5f256/image_embedding_phi3_v.py#L259
903                    for i in 0..pixel_values.dim(0)? {
904                        let cnt = self.num_img_tokens;
905                        let img_set_tensor = image_set_tensor
906                            .i(i * cnt..(i + 1) * cnt)?
907                            .to_device(&target_dev)?
908                            .to_dtype(target_dtype)?;
909                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
910                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
911                        // hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
912                        hidden_states = hidden_states.slice_assign(
913                            &[p_0..p_0 + 1, p_1..p_1 + cnt, 0..img_set_tensor.dims()[2]],
914                            &img_set_tensor,
915                        )?;
916                        idx += cnt;
917                    }
918                }
919                _ => unreachable!(),
920            }
921        }
922
923        Ok(hidden_states)
924    }
925
926    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
927        let uvb = UnVarBuilder::new();
928
929        if let Some(glb_gn) = self.glb_gn.clone() {
930            uvb.add_tensor("glb_GN", glb_gn);
931        }
932        if let Some(sub_gn) = self.sub_gn.clone() {
933            uvb.add_tensor("sub_GN", sub_gn);
934        }
935        uvb.extend(self.tensors.clone());
936        uvb.pp("img_processor.vision_model")
937            .extend(self.image_processor.residual_tensors());
938
939        uvb.to_safetensors()
940    }
941}
942
943// =================== ============= ===================
944
pub struct Model {
    /// Vision embedding: merges projected image features into token embeddings.
    vision_embed_tokens: ImageEmbedding,
    /// Plain token embedding, used when no pixel values are supplied.
    embed_tokens: candle_nn::Embedding,
    /// The repeated transformer decoder layers.
    layers: Vec<DecoderLayer>,
    /// Final RMS norm applied before the LM head.
    norm: RmsNorm,
    /// Output projection to vocabulary logits (may be tied to `embed_tokens`).
    lm_head: Arc<dyn QuantMethod>,
    /// Device the final activations/logits are computed on.
    device: Device,
    /// Per-layer KV cache (sliding-window variant; see `Model::new`).
    cache: EitherCache,
    /// Maximum supported sequence length (`max_position_embeddings`).
    max_seq_len: usize,
    /// Maps layers and activations across devices for multi-device inference.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Optional sliding attention window size from the config.
    sliding_window: Option<usize>,
    /// Metadata consumed by the pipeline (head counts, dims, etc.).
    cfg: ModelConfigMetadata,
}
958
impl Model {
    /// Build the full Phi-3-Vision model from the `model.*` weights in `vb`.
    ///
    /// * `cfg` - model configuration (text + vision) from the checkpoint.
    /// * `vb` - sharded weight source rooted at the checkpoint top level.
    /// * `_is_gptx` - unused here; kept for signature parity with other models.
    /// * `normal_loading_metadata` - device mapper, real device, ISQ flag and
    ///   progress reporting used during loading.
    /// * `attention_mechanism` - eager SDPA or paged attention.
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        // The vision embedding shares the token embedding so placeholder ids
        // and image features end up in the same embedding space.
        let vision_embed_tokens = ImageEmbedding::new(
            cfg,
            embed_tokens.clone(),
            &cfg.embd_layer,
            mapper.set_nm_device(vb_m.pp("vision_embed_tokens"), false),
        )?;
        let vb_l = vb_m.pp("layers");
        // Build one rotary embedding per distinct device so that layers mapped
        // to different devices each get a device-local copy.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        // Load decoder layers (in parallel when ISQ is active), with progress.
        let layers = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            DecoderLayer::new(
                rotary_emb,
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        // With tied word embeddings the LM head reuses the embedding matrix
        // instead of loading a separate `lm_head` weight.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        Ok(Self {
            vision_embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Head counts are divided across tensor-parallel ranks; KV
                // heads are floored at 1.
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    /// Run the model over one batch and return logits for the positions
    /// selected by `context_lens`.
    ///
    /// When `pixel_values` is present, image features are merged into the
    /// token embeddings by `vision_embed_tokens`; otherwise the plain token
    /// embedding is used.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(usize, usize)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // Embed tokens, splicing in projected image features when present.
        let mut xs = if let Some(ref pixel_values) = pixel_values {
            self.vision_embed_tokens
                .forward(input_ids, pixel_values, image_sizes)?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        // Sliding-window causal mask. With paged attention the past KV length
        // comes from `seqlen_offsets`; otherwise it comes from the cache.
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // With paged attention, only the first prompt chunk needs an explicit
        // mask; later chunks drop it.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to whichever device this layer was mapped to.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        // Match the LM head's quantized activation dtype if it has one.
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
1135
1136impl IsqModel for Model {
1137    fn get_layers(
1138        &mut self,
1139    ) -> (
1140        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
1141        &dyn DeviceMapper,
1142    ) {
1143        let mut tensors = Vec::new();
1144        tensors.push((&mut self.lm_head, None));
1145        for (i, layer) in self.layers.iter_mut().enumerate() {
1146            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
1147            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
1148            tensors.extend(
1149                layer
1150                    .mlp
1151                    .get_isq_layers()
1152                    .into_iter()
1153                    .map(|m| (m, Some(i)))
1154                    .collect::<Vec<_>>(),
1155            );
1156        }
1157        (tensors, &*self.mapper)
1158    }
1159
1160    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
1161        let uvb = UnVarBuilder::new();
1162
1163        let uvb_m = uvb.pp("model");
1164        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
1165        uvb_m.pp("norm").add(&self.norm);
1166        uvb_m
1167            .pp("vision_embed_tokens")
1168            .extend(self.vision_embed_tokens.residual_tensors());
1169
1170        for (layer_idx, layer) in self.layers.iter().enumerate() {
1171            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
1172            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
1173            uvb_l
1174                .pp("post_attention_layernorm")
1175                .add(&layer.post_attention_layernorm);
1176        }
1177
1178        uvb.to_safetensors()
1179    }
1180}
1181
/// Vision-specific extra inputs threaded through the generic `VisionModel`
/// interface as a `Box<dyn Any>`.
#[derive(Default)]
pub(crate) struct Phi3VisionSpecificArgs {
    // Per-image `(usize, usize)` sizes forwarded to the image embedding;
    // presumably (height, width) — confirm against the inputs processor.
    pub image_sizes: Option<Vec<(usize, usize)>>,
}
1186
1187impl VisionModel for Model {
1188    fn forward(
1189        &self,
1190        input_ids: &Tensor,
1191        pixel_values: Option<Tensor>,
1192        seqlen_offsets: &[usize],
1193        context_lens: Vec<(usize, usize)>,
1194        position_ids: Vec<usize>,
1195        model_specific_args: Box<dyn Any>,
1196        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1197        flash_params: &FlashParams,
1198    ) -> Result<Tensor> {
1199        let Phi3VisionSpecificArgs { image_sizes } = *model_specific_args
1200            .downcast()
1201            .expect("Cannot downcast into `Phi3VisionSpecificArgs`");
1202        self.forward(
1203            input_ids,
1204            pixel_values,
1205            seqlen_offsets,
1206            &position_ids,
1207            context_lens,
1208            image_sizes,
1209            metadata,
1210            flash_params,
1211        )
1212    }
1213    fn cache(&self) -> &EitherCache {
1214        &self.cache
1215    }
1216    fn cache_mut(&mut self) -> &mut EitherCache {
1217        &mut self.cache
1218    }
1219    fn device(&self) -> &Device {
1220        &self.device
1221    }
1222    fn max_seq_len(&self) -> usize {
1223        self.max_seq_len
1224    }
1225    fn config(&self) -> &ModelConfigMetadata {
1226        &self.cfg
1227    }
1228    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
1229        Box::new(Phi3VisionSpecificArgs::default())
1230    }
1231}
1232
1233impl AnyMoeBaseModelMixin for Model {
1234    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
1235        let mut mlps = Vec::new();
1236        for layer in &self.layers {
1237            mlps.push(&*layer.mlp);
1238        }
1239        mlps
1240    }
1241    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
1242        let mut mlps = Vec::new();
1243        for layer in &mut self.layers {
1244            mlps.push(&mut layer.mlp);
1245        }
1246        mlps
1247    }
1248    fn create_anymoe_layers(
1249        &mut self,
1250        additional_vbs: Vec<ShardedVarBuilder>,
1251        config: AnyMoeConfig,
1252        (prefix, mlp): (String, String),
1253        mut layers: Vec<usize>,
1254        expert_type: AnyMoeExpertType,
1255        gate_vb: Option<ShardedVarBuilder>,
1256    ) -> Result<()> {
1257        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
1258        if layers.is_empty() {
1259            layers = (0..self.layers.len()).collect::<Vec<_>>();
1260        }
1261        for _ in 0..layers.len() {
1262            experts.push(Vec::new());
1263        }
1264        for vb in additional_vbs {
1265            let vb = vb.pp(&prefix);
1266            for (layer, row) in experts.iter_mut().enumerate() {
1267                if !layers.contains(&layer) {
1268                    continue;
1269                }
1270
1271                let intermediate_size = self.layers[layer].mlp.get_params()[1];
1272                let hidden_size = self.layers[layer].mlp.get_params()[0];
1273                match expert_type {
1274                    AnyMoeExpertType::FineTuned => {
1275                        row.push(Box::new(Mlp::new(
1276                            &Config {
1277                                intermediate_size: self.layers[layer].mlp.get_params()[1],
1278                                hidden_size: self.layers[layer].mlp.get_params()[0],
1279                                ..Default::default()
1280                            },
1281                            vb.pp(layer).pp(&mlp),
1282                        )?));
1283                    }
1284                    AnyMoeExpertType::LoraAdapter {
1285                        rank,
1286                        alpha,
1287                        ref target_modules,
1288                    } => {
1289                        let vb_mlp = vb.pp(layer).pp(&mlp);
1290
1291                        let gate_up_proj_delta =
1292                            if target_modules.contains(&"gate_up_proj".to_string()) {
1293                                Some(get_delta_from_lora_ab!(
1294                                    vb_mlp,
1295                                    rank,
1296                                    alpha,
1297                                    (hidden_size, 2 * intermediate_size),
1298                                    "gate_up_proj"
1299                                ))
1300                            } else {
1301                                None
1302                            };
1303                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
1304                            Some(get_delta_from_lora_ab!(
1305                                vb_mlp,
1306                                rank,
1307                                alpha,
1308                                (hidden_size, intermediate_size),
1309                                "down_proj"
1310                            ))
1311                        } else {
1312                            None
1313                        };
1314
1315                        row.push(
1316                            self.layers[layer]
1317                                .mlp
1318                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
1319                        );
1320                    }
1321                }
1322            }
1323        }
1324        for (layer, expert) in layers.into_iter().zip(experts) {
1325            let mut experts_all = vec![self.layers[layer].mlp.clone()];
1326            experts_all.extend(expert);
1327            let (dtype, device) = self.layers[layer].mlp.dtype_device();
1328            self.layers[layer].mlp = Box::new(MoeMlp::new(
1329                experts_all,
1330                config.clone(),
1331                dtype,
1332                &device,
1333                layer,
1334                gate_vb.as_ref(),
1335            )?);
1336        }
1337        Ok(())
1338    }
1339    fn amoe_supported(&self) -> bool {
1340        true
1341    }
1342}