mistralrs_core/vision_models/phi3/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

pub(crate) mod phi3_inputs_processor;

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use candle_core::{
    shape::ShapeWithOneHole, DType, Device, IndexOp, Module, Result, Shape, Tensor, D,
};
use either::Either;
use mistralrs_quant::{QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder};
use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        self, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    ops::{BitWiseOp, NonZeroOp},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
    vision_models::clip::{ClipConfig, ClipVisionTransformer},
    AnyMoeConfig, AnyMoeExpertType,
};

use super::clip;

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct EmbedLayerConfig {
    pub hd_transform_order: Option<String>,
    pub projection_cls: Option<String>,
    pub use_hd_transform: Option<bool>,
    pub with_learnable_separator: Option<bool>,
}

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct ImageProcessorConfig {
    pub image_dim_out: usize,
    pub name: String,
    pub num_img_tokens: usize,
    pub layer_idx: Option<isize>,
    pub type_feature: Option<String>,
}

serde_default_fn!(bool, d_flash_attn, false);
serde_default_fn!(bool, word_emb_default, false);

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    #[serde(default = "d_flash_attn")]
    pub use_flash_attn: bool,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    pub embd_layer: EmbedLayerConfig,
    pub img_processor: ImageProcessorConfig,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}
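
// `Config` is deserialized directly from the checkpoint's `config.json`. A
// minimal sketch of the vision-specific fields (values as in the public
// microsoft/Phi-3-vision-128k-instruct config; shown for illustration only):
//
// "embd_layer": {
//     "hd_transform_order": "sub_glb",
//     "projection_cls": "mlp",
//     "use_hd_transform": true,
//     "with_learnable_separator": true
// },
// "img_processor": {
//     "image_dim_out": 1024,
//     "name": "clip_vision_model",
//     "num_img_tokens": 144
// }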

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}
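
// Note: for the public Phi-3-vision checkpoints (hidden_size = 3072,
// num_attention_heads = 32), `head_dim()` works out to 96.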

trait ModuleWithMetadata: Module + Debug + Send + Sync {
    fn device(&self) -> Device;
    fn dtype(&self) -> DType;
}

#[derive(Debug)]
struct QuantMethodWrapper(Arc<dyn QuantMethod>);

impl Module for QuantMethodWrapper {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}

impl ModuleWithMetadata for QuantMethodWrapper {
    fn device(&self) -> Device {
        self.0.unquant_weight_bias().unwrap().0.device().clone()
    }
    fn dtype(&self) -> DType {
        self.0.unquant_weight_bias().unwrap().0.dtype()
    }
}

impl ModuleWithMetadata for candle_nn::Activation {
    fn device(&self) -> Device {
        unreachable!()
    }
    fn dtype(&self) -> DType {
        unreachable!()
    }
}

#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));

fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

impl ShapeWithOneHole for BigShapeWithOneHole {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, d5, ()) = self.0;
        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
        Ok((d1, d2, d3, d4, d5, d).into())
    }
}
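
// `BigShapeWithOneHole` extends candle's hole-based reshape to 6-D shapes: the
// `()` marks the single dimension to infer. For example, reshaping a tensor of
// 1 * 2 * 3 * 12 * 12 * 4096 elements with (1, 2, 3, 12, 12, ()) infers the
// hole as el_count / (1 * 2 * 3 * 12 * 12) = 4096.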

// =================== BASE LAYERS ===================

struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        // No TP here.
        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
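        // The fused projection lays Q, K, V out contiguously along the last
        // dimension: [q: num_heads * head_dim | k: num_kv_heads * head_dim |
        // v: num_kv_heads * head_dim], so the three can be recovered with
        // `narrow` at the offsets below.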
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If there is no metadata, we are most likely generating an imatrix, so we do not
                    // want to populate the KV cache. Generate dummy metadata under the assumption that
                    // we are only processing prompts, not generating text.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    params: Vec<usize>,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        // No TP here.
        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
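        // `gate_up_proj` fuses the gate and up projections: the first `i_size`
        // output features are the gate, the next `i_size` are the up states.
        // The MLP then computes down_proj(up * act_fn(gate)).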
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    // gate_up, down
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}

struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

// =================== ============= ===================

// =================== VISION LAYERS ===================

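// Image placeholder token ids are negative; `MAX_INPUT_ID` bounds how negative
// they can be, so ids in (-MAX_INPUT_ID, 0) are treated as image positions.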
const MAX_INPUT_ID: f64 = 1e9;

#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);

impl Module for EmbeddingLayers {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in &self.0 {
            xs = layer.forward(&xs)?;
        }
        Ok(xs)
    }
}

#[derive(Debug)]
pub struct ImageEmbedding {
    wte: candle_nn::Embedding,
    image_dim_out: usize,
    num_img_tokens: usize,
    glb_gn: Option<Tensor>,
    sub_gn: Option<Tensor>,
    layers: EmbeddingLayers,
    type_feature: String,
    layer_idx: isize,
    image_processor: ClipVisionTransformer,
    hd_transform_order: String,
    use_hd_transform: bool,
    vocab_size: usize,
    tensors: Vec<(String, Tensor)>,
}

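// These values match the CLIP ViT-L/14 vision tower at 336px resolution
// (openai/clip-vit-large-patch14-336), which Phi-3-vision uses as its image
// encoder: 24 layers, hidden size 1024, 16 heads, 14px patches.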
pub(crate) const PHI3V_CLIP_CONFIG: ClipConfig = ClipConfig {
    hidden_act: clip::Activation::QuickGelu,
    hidden_size: 1024,
    image_size: 336,
    intermediate_size: 4096,
    num_attention_heads: 16,
    num_channels: 3,
    num_hidden_layers: 24,
    patch_size: 14,
};

impl ImageEmbedding {
    fn new(
        config: &Config,
        wte: candle_nn::Embedding,
        embed_config: &EmbedLayerConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let hidden_size = config.hidden_size;
        if config.img_processor.name != "clip_vision_model" {
            candle_core::bail!(
                "img_processor=`{}` not supported.",
                config.img_processor.name
            );
        }
        let image_dim_out = config.img_processor.image_dim_out;
        let num_img_tokens = config.img_processor.num_img_tokens;

        // CLIP image encoder
        let image_processor =
            ClipVisionTransformer::new(vb.pp("img_processor.vision_model"), &PHI3V_CLIP_CONFIG)?;

        // High-definition (HD) transform
        let use_hd_transform = embed_config.use_hd_transform.unwrap_or(false);
        let with_learnable_separator = embed_config.with_learnable_separator.unwrap_or(false);
        let hd_transform_order = embed_config
            .hd_transform_order
            .clone()
            .unwrap_or("glb_sub".to_string());
        assert_eq!(use_hd_transform, with_learnable_separator);
        let (glb_gn, sub_gn) = if with_learnable_separator {
            let glb_gn = vb.get((1, 1, image_dim_out * 4), "glb_GN")?;
            let sub_gn = vb.get((1, 1, 1, image_dim_out * 4), "sub_GN")?;
            (Some(glb_gn), Some(sub_gn))
        } else {
            (None, None)
        };

        // Inner projection
        let projection_cls = embed_config
            .projection_cls
            .clone()
            .unwrap_or("linear".to_string());

        let mut tensors = Vec::new();
        let layers: Vec<Box<dyn ModuleWithMetadata>> =
            match (projection_cls.as_str(), use_hd_transform) {
                ("linear", _) => {
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        hidden_size,
                        true,
                        &None,
                        vb.pp("img_projection"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.bias".to_string(), b));
                    }
                    vec![Box::new(QuantMethodWrapper(a))]
                }
                ("mlp", true) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out * 4,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                ("mlp", false) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                _ => {
                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
                }
            };

        let layer_idx = config.img_processor.layer_idx.unwrap_or(-2);
        let type_feature = config
            .img_processor
            .type_feature
            .clone()
            .unwrap_or("patch".to_string());

        Ok(Self {
            wte,
            image_dim_out,
            num_img_tokens,
            glb_gn,
            sub_gn,
            layer_idx,
            type_feature,
            image_processor,
            layers: EmbeddingLayers(layers),
            hd_transform_order,
            use_hd_transform,
            vocab_size: config.vocab_size,
            tensors,
        })
    }

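    // With the default `layer_idx` of -2, this selects the penultimate hidden
    // state of the CLIP tower, mirroring the reference implementation's
    // `img_processor.layer_idx`. `type_feature == "patch"` drops the leading
    // CLS token so only patch embeddings remain.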
    fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let hidden_states = self
            .image_processor
            .forward_get_hidden_states(&pixel_values.to_dtype(self.wte.embeddings().dtype())?)?;
        let img_feature =
            hidden_states[(hidden_states.len() as isize + self.layer_idx) as usize].clone();
        if self.type_feature == "patch" {
            img_feature.i((.., 1..))
        } else if self.type_feature == "cls_patch" {
            Ok(img_feature)
        } else {
            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
        }
    }

    #[allow(non_snake_case)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: &Tensor,
        image_sizes: Option<Vec<(usize, usize)>>,
    ) -> Result<Tensor> {
        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;

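        // Image placeholders arrive as negative input ids (the processor maps
        // `<|image_1|>` to -1, `<|image_2|>` to -2, ...), so the positions
        // found here mark the slots that are overwritten with projected image
        // features below.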
        let input_ids_lt = input_ids.lt(0.0f64)?;
        let input_ids_gt = input_ids.gt(-MAX_INPUT_ID)?;
        // positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
        let positions = input_ids_lt.bitwise_and(&input_ids_gt)?.nonzero()?;
        let target_dev = self.layers.0[0].device();
        let target_dtype = self.layers.0[0].dtype();

        let mut select = false;
        // If `Some`, we are in the HD transform case and it holds the per-image token counts.
        let mut hd_transform = None;
        let mut image_set_tensor = None;
        if positions.dim(0)? > 0 {
            select = true;
            // input_ids[positions[:, 0], positions[:, 1]]
            if self.use_hd_transform && image_sizes.is_some() {
                assert_eq!(pixel_values.dims().len(), 5);
                let bs = pixel_values.dim(0)?;
                let img_features = self.get_image_features(&pixel_values.flatten(0, 1)?)?;
                let base_feat_dim = (img_features.dims()[1] as f32).sqrt() as usize;
                assert_eq!(base_feat_dim, 24);

                // bs x max_num_crops x (24x24) x C
                let img_features =
                    img_features.reshape((bs, (), base_feat_dim.pow(2), self.image_dim_out))?;
                let C = self.image_dim_out;
                let H = base_feat_dim;

                let mut output_imgs = Vec::new();
                let mut output_len = Vec::new();
                for bs_ in 0..bs {
                    let (h, w) = image_sizes.as_ref().unwrap()[bs_];
                    let h = h / 336;
                    let w = w / 336;
                    let B_ = h * w;

                    // 1 x (24x24) x 1024
                    let global_img_feature = img_features.i((bs_, ..1))?;

                    // 1 x 12 x 12 x 4096
                    let glb_img = global_img_feature
                        .reshape((1, H, H, C))?
                        .reshape((1, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, H / 2, H / 2, 4 * C))?
                        .contiguous()?;
                    let temp_glbl_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, H / 2, 1, 1))?;

                    // 1 x 156 x 4096
                    let glb_img =
                        Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((1, (), 4 * C))?;

                    // (max_num_crops-1) x (12x12) x C
                    let sub_img = img_features.i((bs_, 1..))?;

                    // 16x574x1024
                    // Get rid of padding sub_img
                    let sub_img = sub_img.i(..B_)?;

                    // (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
                    let sub_img = sub_img
                        .reshape((B_, H, H, C))?
                        .reshape((B_, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((B_, (), 4 * C))?
                        .contiguous()?;
                    let sub_img = sub_img
                        .reshape(BigShapeWithOneHole((1usize, h, w, 12usize, 12usize, ())))?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, h * 12, w * 12, 4 * C))?;
                    let temp_sub_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, h * 12, 1, 1))?;

                    let sub_img =
                        Tensor::cat(&[sub_img, temp_sub_gn], 2)?.reshape((1, (), 4 * C))?;

                    // (1, num_img_tokens, 1024*4)

                    match self.hd_transform_order.as_str() {
                        "glb_sub" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    glb_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    sub_img,
                                ],
                                1,
                            )?);
                        }
                        "sub_glb" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    sub_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    glb_img,
                                ],
                                1,
                            )?);
                        }
                        other => {
                            candle_core::bail!("Invalid hd_transform_order=`{other}`");
                        }
                    }

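                    // Expected token count: (h*w sub-crops + 1 global crop) at
                    // 12*12 = 144 tokens each, plus one `glb_GN` separator,
                    // plus (h*12 sub rows + 12 global rows) `sub_GN` row
                    // separators = (h*w + 1) * 144 + 1 + (h + 1) * 12.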
                    let temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12;
                    assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
                    output_len.push(temp_len);
                }

                hd_transform = Some(output_len);
                let mut image_set_tensor_inner = Vec::new();
                for img in output_imgs {
                    let layerout = self
                        .layers
                        .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
                    image_set_tensor_inner.push(layerout);
                }
                image_set_tensor = Some(Either::Left(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 4 {
                let tt = self
                    .get_image_features(pixel_values)?
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 3 {
                let tt = pixel_values
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else {
                unreachable!()
            }
        }

        let input_ids = input_ids.clamp(0.0, self.vocab_size as f64)?;
        let mut hidden_states = self.wte.forward(&input_ids)?;
        if select {
            match (hd_transform, image_set_tensor) {
                (Some(output_lens), Some(Either::Left(image_set_tensors))) => {
                    let mut idx = 0;
                    for (i, cnt) in output_lens.into_iter().enumerate() {
                        let img_set_tensor = image_set_tensors[i]
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        // hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                (None, Some(Either::Right(image_set_tensor))) => {
                    let mut idx = 0;
                    // We know len(img_embeds) == pixel_values.dim(0) == len(selected_g_values)
                    // https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/dbcdaaacf52c8e40cf8de6d6ffa6ff6860e5f256/image_embedding_phi3_v.py#L259
                    for i in 0..pixel_values.dim(0)? {
                        let cnt = self.num_img_tokens;
                        let img_set_tensor = image_set_tensor
                            .i(i * cnt..(i + 1) * cnt)?
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        // hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                _ => unreachable!(),
            }
        }

        Ok(hidden_states)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        if let Some(glb_gn) = self.glb_gn.clone() {
            uvb.add_tensor("glb_GN", glb_gn);
        }
        if let Some(sub_gn) = self.sub_gn.clone() {
            uvb.add_tensor("sub_GN", sub_gn);
        }
        uvb.extend(self.tensors.clone());
        uvb.pp("img_processor.vision_model")
            .extend(self.image_processor.residual_tensors());

        uvb.to_safetensors()
    }
}

// =================== ============= ===================

pub struct Model {
    vision_embed_tokens: ImageEmbedding,
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let vision_embed_tokens = ImageEmbedding::new(
            cfg,
            embed_tokens.clone(),
            &cfg.embd_layer,
            mapper.set_nm_device(vb_m.pp("vision_embed_tokens"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        Ok(Self {
            vision_embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(usize, usize)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = if let Some(ref pixel_values) = pixel_values {
            self.vision_embed_tokens
                .forward(input_ids, pixel_values, image_sizes)?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);
        uvb_m
            .pp("vision_embed_tokens")
            .extend(self.vision_embed_tokens.residual_tensors());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

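// `image_sizes` carries the (height, width) in pixels of each HD-transformed
// image; both are multiples of 336, and `ImageEmbedding::forward` derives the
// crop grid as (height / 336, width / 336).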
#[derive(Default)]
pub(crate) struct Phi3VisionSpecificArgs {
    pub image_sizes: Option<Vec<(usize, usize)>>,
}

impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Phi3VisionSpecificArgs { image_sizes } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Phi3VisionSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            &position_ids,
            context_lens,
            image_sizes,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(Phi3VisionSpecificArgs::default())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        row.push(Box::new(Mlp::new(
                            &Config {
                                intermediate_size,
                                hidden_size,
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp),
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_up_proj_delta =
                            if target_modules.contains(&"gate_up_proj".to_string()) {
                                Some(get_delta_from_lora_ab!(
                                    vb_mlp,
                                    rank,
                                    alpha,
                                    (hidden_size, 2 * intermediate_size),
                                    "gate_up_proj"
                                ))
                            } else {
                                None
                            };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}