mistralrs_core/vision_models/phi3/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

pub(crate) mod phi3_inputs_processor;

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use candle_core::{
    shape::ShapeWithOneHole, DType, Device, IndexOp, Module, Result, Shape, Tensor, D,
};
use either::Either;
use mistralrs_quant::{
    BitWiseOp, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder,
};
use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        self, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
    vision_models::clip::{ClipConfig, ClipVisionTransformer},
    AnyMoeConfig, AnyMoeExpertType,
};

use super::clip;

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct EmbedLayerConfig {
    pub hd_transform_order: Option<String>,
    pub projection_cls: Option<String>,
    pub use_hd_transform: Option<bool>,
    pub with_learnable_separator: Option<bool>,
}

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct ImageProcessorConfig {
    pub image_dim_out: usize,
    pub name: String,
    pub num_img_tokens: usize,
    pub layer_idx: Option<isize>,
    pub type_feature: Option<String>,
}

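// Test-only sketch (not in the upstream file): these configs are deserialized
// straight from the model's `config.json`, and missing optional fields become
// `None`. Assumes `serde_json` is available as elsewhere in the crate; the
// JSON literal is illustrative, not a real checkpoint config.
#[cfg(test)]
mod embed_cfg_tests {
    use super::*;

    #[test]
    fn embed_layer_config_from_json() {
        let cfg: EmbedLayerConfig =
            serde_json::from_str(r#"{"hd_transform_order": "glb_sub", "use_hd_transform": true}"#)
                .unwrap();
        assert_eq!(cfg.hd_transform_order.as_deref(), Some("glb_sub"));
        assert_eq!(cfg.use_hd_transform, Some(true));
        assert!(cfg.projection_cls.is_none());
    }
}
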
serde_default_fn!(bool, word_emb_default, false);

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    pub embd_layer: EmbedLayerConfig,
    pub img_processor: ImageProcessorConfig,
    #[serde(alias = "quantization")]
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

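// A minimal, test-only sketch (not part of the upstream file) showing that
// `head_dim` is derived rather than stored. The 3072/32 values below are
// illustrative assumptions, not read from a real `config.json`.
#[cfg(test)]
mod config_tests {
    use super::*;

    #[test]
    fn head_dim_is_hidden_size_over_heads() {
        let cfg = Config {
            hidden_size: 3072,
            num_attention_heads: 32,
            ..Default::default()
        };
        assert_eq!(cfg.head_dim(), 96);
    }
}
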
trait ModuleWithMetadata: Module + Debug + Send + Sync {
    fn device(&self) -> Device;
    fn dtype(&self) -> DType;
}

#[derive(Debug)]
struct QuantMethodWrapper(Arc<dyn QuantMethod>);

impl Module for QuantMethodWrapper {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}

impl ModuleWithMetadata for QuantMethodWrapper {
    fn device(&self) -> Device {
        self.0.unquant_weight_bias().unwrap().0.device().clone()
    }
    fn dtype(&self) -> DType {
        self.0.unquant_weight_bias().unwrap().0.dtype()
    }
}

impl ModuleWithMetadata for candle_nn::Activation {
    fn device(&self) -> Device {
        unreachable!()
    }
    fn dtype(&self) -> DType {
        unreachable!()
    }
}

#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));

fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

impl ShapeWithOneHole for BigShapeWithOneHole {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, d5, ()) = self.0;
        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
        Ok((d1, d2, d3, d4, d5, d).into())
    }
}

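// A minimal, test-only sketch (not in the upstream file) of how the one-hole
// shape resolves: the `()` dimension is solved as el_count / product(known dims).
#[cfg(test)]
mod shape_tests {
    use super::*;

    #[test]
    fn one_hole_resolves_trailing_dim() -> Result<()> {
        // 1*2*2*12*12 = 576 known elements; 2304 / 576 = 4 fills the hole.
        let shape = BigShapeWithOneHole((1, 2, 2, 12, 12, ())).into_shape(2304)?;
        assert_eq!(shape.dims(), &[1, 2, 2, 12, 12, 4]);
        Ok(())
    }
}
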
// =================== BASE LAYERS ===================

struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        // No TP here.
        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
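        // Fused QKV layout along the last dim: [Q | K | V] with widths
        // num_heads*head_dim, num_kv_heads*head_dim, num_kv_heads*head_dim.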
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we
                    // don't want to populate the KV cache. Construct dummy metadata under the
                    // assumption that we are only processing prompts, not generating text.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    params: Vec<usize>,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        // No TP here.
        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
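        // `gate_up_proj` is fused: the first `i_size` output channels are the
        // gate, the next `i_size` are the up projection.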
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    // gate_up, down
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}

struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

// =================== ============= ===================

// =================== VISION LAYERS ===================

const MAX_INPUT_ID: f64 = 1e9;

#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);

impl Module for EmbeddingLayers {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in &self.0 {
            xs = layer.forward(&xs)?;
        }
        Ok(xs)
    }
}

#[derive(Debug)]
pub struct ImageEmbedding {
    wte: candle_nn::Embedding,
    image_dim_out: usize,
    num_img_tokens: usize,
    glb_gn: Option<Tensor>,
    sub_gn: Option<Tensor>,
    layers: EmbeddingLayers,
    type_feature: String,
    layer_idx: isize,
    image_processor: ClipVisionTransformer,
    hd_transform_order: String,
    use_hd_transform: bool,
    vocab_size: usize,
    tensors: Vec<(String, Tensor)>,
}

pub(crate) const PHI3V_CLIP_CONFIG: ClipConfig = ClipConfig {
    hidden_act: clip::Activation::QuickGelu,
    hidden_size: 1024,
    image_size: 336,
    intermediate_size: 4096,
    num_attention_heads: 16,
    num_channels: 3,
    num_hidden_layers: 24,
    patch_size: 14,
};

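// Test-only arithmetic sketch (not in the upstream file): with image_size 336
// and patch_size 14, CLIP produces a 24x24 patch grid (576 patch tokens plus
// one CLS token), which is why the HD transform below asserts a base feature
// dim of 24.
#[cfg(test)]
mod clip_grid_tests {
    use super::*;

    #[test]
    fn phi3v_clip_patch_grid_is_24() {
        let side = PHI3V_CLIP_CONFIG.image_size / PHI3V_CLIP_CONFIG.patch_size;
        assert_eq!(side, 24);
        assert_eq!(side * side, 576);
    }
}
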
impl ImageEmbedding {
    fn new(
        config: &Config,
        wte: candle_nn::Embedding,
        embed_config: &EmbedLayerConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let hidden_size = config.hidden_size;
        if config.img_processor.name != "clip_vision_model" {
            candle_core::bail!(
                "img_processor=`{}` not supported.",
                config.img_processor.name
            );
        }
        let image_dim_out = config.img_processor.image_dim_out;
        let num_img_tokens = config.img_processor.num_img_tokens;

        // CLIP image processor here...
        let image_processor =
            ClipVisionTransformer::new(vb.pp("img_processor.vision_model"), &PHI3V_CLIP_CONFIG)?;

        // High dim transform
        let use_hd_transform = embed_config.use_hd_transform.unwrap_or(false);
        let with_learnable_separator = embed_config.with_learnable_separator.unwrap_or(false);
        let hd_transform_order = embed_config
            .hd_transform_order
            .clone()
            .unwrap_or("glb_sub".to_string());
        assert_eq!(use_hd_transform, with_learnable_separator);
        let (glb_gn, sub_gn) = if with_learnable_separator {
            let glb_gn = vb.get((1, 1, image_dim_out * 4), "glb_GN")?;
            let sub_gn = vb.get((1, 1, 1, image_dim_out * 4), "sub_GN")?;
            (Some(glb_gn), Some(sub_gn))
        } else {
            (None, None)
        };

        // Inner projection
        let projection_cls = embed_config
            .projection_cls
            .clone()
            .unwrap_or("linear".to_string());

        let mut tensors = Vec::new();
        let layers: Vec<Box<dyn ModuleWithMetadata>> =
            match (projection_cls.as_str(), use_hd_transform) {
                ("linear", _) => {
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        hidden_size,
                        true,
                        &None,
                        vb.pp("img_projection"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.bias".to_string(), b));
                    }
                    vec![Box::new(QuantMethodWrapper(a))]
                }
                ("mlp", true) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out * 4,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                ("mlp", false) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                _ => {
                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
                }
            };

        let layer_idx = config.img_processor.layer_idx.unwrap_or(-2);
        let type_feature = config
            .img_processor
            .type_feature
            .clone()
            .unwrap_or("patch".to_string());

        Ok(Self {
            wte,
            image_dim_out,
            num_img_tokens,
            glb_gn,
            sub_gn,
            layer_idx,
            type_feature,
            image_processor,
            layers: EmbeddingLayers(layers),
            hd_transform_order,
            use_hd_transform,
            vocab_size: config.vocab_size,
            tensors,
        })
    }

    fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let hidden_states = self
            .image_processor
            .forward_get_hidden_states(&pixel_values.to_dtype(self.wte.embeddings().dtype())?)?;
        let img_feature =
            hidden_states[(hidden_states.len() as isize + self.layer_idx) as usize].clone();
        if self.type_feature == "patch" {
            img_feature.i((.., 1..))
        } else if self.type_feature == "cls_patch" {
            Ok(img_feature)
        } else {
            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
        }
    }

    #[allow(non_snake_case)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: &Tensor,
        image_sizes: Option<Vec<(usize, usize)>>,
    ) -> Result<Tensor> {
        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;

        let input_ids_lt = input_ids.lt(0.0f64)?;
        let input_ids_gt = input_ids.gt(-MAX_INPUT_ID)?;
        // positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
        let positions = input_ids_lt.bitwise_and(&input_ids_gt)?.nonzero()?;
        let target_dev = self.layers.0[0].device();
        let target_dtype = self.layers.0[0].dtype();

        let mut select = false;
        // If `Some`, we are in the HD-transform case; holds the number of image tokens per image.
        let mut hd_transform = None;
        let mut image_set_tensor = None;
        if positions.dim(0)? > 0 {
            select = true;
            // input_ids[positions[:, 0], positions[:, 1]]
            if self.use_hd_transform && image_sizes.is_some() {
                assert_eq!(pixel_values.dims().len(), 5);
                let bs = pixel_values.dim(0)?;
                let img_features = self.get_image_features(&pixel_values.flatten(0, 1)?)?;
                let base_feat_dim = (img_features.dims()[1] as f32).sqrt() as usize;
                assert_eq!(base_feat_dim, 24);

                // bs x max_num_crops x (24x24) x C
                let img_features =
                    img_features.reshape((bs, (), base_feat_dim.pow(2), self.image_dim_out))?;
                let C = self.image_dim_out;
                let H = base_feat_dim;

                let mut output_imgs = Vec::new();
                let mut output_len = Vec::new();
                for bs_ in 0..bs {
                    let (h, w) = image_sizes.as_ref().unwrap()[bs_];
                    let h = h / 336;
                    let w = w / 336;
                    let B_ = h * w;

                    // 1 x (24x24) x 1024
                    let global_img_feature = img_features.i((bs_, ..1))?;

                    // 1 x 12 x 12 x 4096
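                    // Fold each 2x2 patch neighborhood into the channel dim
                    // (a pixel-unshuffle): 24x24xC -> 12x12x4C.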
                    let glb_img = global_img_feature
                        .reshape((1, H, H, C))?
                        .reshape((1, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, H / 2, H / 2, 4 * C))?
                        .contiguous()?;
                    let temp_glbl_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, H / 2, 1, 1))?;

                    // 1 x 156 x 4096
                    let glb_img =
                        Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((1, (), 4 * C))?;

                    // (max_num_crops-1) x (12x12) x C
                    let sub_img = img_features.i((bs_, 1..))?;

                    // 16x574x1024
                    // Get rid of padding sub_img
                    let sub_img = sub_img.i(..B_)?;

                    // (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
                    let sub_img = sub_img
                        .reshape((B_, H, H, C))?
                        .reshape((B_, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((B_, (), 4 * C))?
                        .contiguous()?;
                    let sub_img = sub_img
                        .reshape(BigShapeWithOneHole((1usize, h, w, 12usize, 12usize, ())))?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, h * 12, w * 12, 4 * C))?;
                    let temp_sub_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, h * 12, 1, 1))?;

                    let sub_img =
                        Tensor::cat(&[sub_img, temp_sub_gn], 2)?.reshape((1, (), 4 * C))?;

                    // (1, num_img_tokens, 1024*4)

                    match self.hd_transform_order.as_str() {
                        "glb_sub" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    glb_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    sub_img,
                                ],
                                1,
                            )?);
                        }
                        "sub_glb" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    sub_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    glb_img,
                                ],
                                1,
                            )?);
                        }
                        other => {
                            candle_core::bail!("Invalid hd_transform_order=`{other}`");
                        }
                    }

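                    // Token count: sub image h*12*(w*12+1) = h*w*144 + 12*h,
                    // global image 12*(12+1) = 156, plus 1 `glb_gn` separator:
                    // (h*w + 1) * 144 + 1 + (h + 1) * 12.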
                    let temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12;
                    assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
                    output_len.push(temp_len);
                }

                hd_transform = Some(output_len);
                let mut image_set_tensor_inner = Vec::new();
                for img in output_imgs {
                    let layerout = self
                        .layers
                        .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
                    image_set_tensor_inner.push(layerout);
                }
                image_set_tensor = Some(Either::Left(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 4 {
                let tt = self
                    .get_image_features(pixel_values)?
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 3 {
                let tt = pixel_values
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else {
                unreachable!()
            }
        }

        let input_ids = input_ids.clamp(0.0, self.vocab_size as f64)?;
        let mut hidden_states = self.wte.forward(&input_ids)?;
        if select {
            match (hd_transform, image_set_tensor) {
                (Some(output_lens), Some(Either::Left(image_set_tensors))) => {
                    let mut idx = 0;
                    for (i, cnt) in output_lens.into_iter().enumerate() {
                        let img_set_tensor = image_set_tensors[i]
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        // hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                (None, Some(Either::Right(image_set_tensor))) => {
                    let mut idx = 0;
                    // We know len(img_embeds) == pixel_values.dim(0) == len(selected_g_values)
                    // https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/dbcdaaacf52c8e40cf8de6d6ffa6ff6860e5f256/image_embedding_phi3_v.py#L259
                    for i in 0..pixel_values.dim(0)? {
                        let cnt = self.num_img_tokens;
                        let img_set_tensor = image_set_tensor
                            .i(i * cnt..(i + 1) * cnt)?
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        // hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                _ => unreachable!(),
            }
        }

        Ok(hidden_states)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        if let Some(glb_gn) = self.glb_gn.clone() {
            uvb.add_tensor("glb_GN", glb_gn);
        }
        if let Some(sub_gn) = self.sub_gn.clone() {
            uvb.add_tensor("sub_GN", sub_gn);
        }
        uvb.extend(self.tensors.clone());
        uvb.pp("img_processor.vision_model")
            .extend(self.image_processor.residual_tensors());

        uvb.to_safetensors()
    }
}

// =================== ============= ===================

pub struct Model {
    vision_embed_tokens: ImageEmbedding,
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let vision_embed_tokens = ImageEmbedding::new(
            cfg,
            embed_tokens.clone(),
            &cfg.embd_layer,
            mapper.set_nm_device(vb_m.pp("vision_embed_tokens"), false),
        )?;
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let layers = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            DecoderLayer::new(
                rotary_emb,
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        Ok(Self {
            vision_embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(usize, usize)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = if let Some(ref pixel_values) = pixel_values {
            self.vision_embed_tokens
                .forward(input_ids, pixel_values, image_sizes)?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);
        uvb_m
            .pp("vision_embed_tokens")
            .extend(self.vision_embed_tokens.residual_tensors());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

#[derive(Default)]
pub(crate) struct Phi3VisionSpecificArgs {
    pub image_sizes: Option<Vec<(usize, usize)>>,
}

impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Phi3VisionSpecificArgs { image_sizes } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Phi3VisionSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            &position_ids,
            context_lens,
            image_sizes,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(Phi3VisionSpecificArgs::default())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        row.push(Box::new(Mlp::new(
                            &Config {
                                intermediate_size,
                                hidden_size,
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp),
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_up_proj_delta =
                            if target_modules.contains(&"gate_up_proj".to_string()) {
                                Some(get_delta_from_lora_ab!(
                                    vb_mlp,
                                    rank,
                                    alpha,
                                    (hidden_size, 2 * intermediate_size),
                                    "gate_up_proj"
                                ))
                            } else {
                                None
                            };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}
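
// Test-only sketch (not in the upstream file) of the `model_specific_args`
// round-trip used by `VisionModel::forward`: the pipeline boxes
// `Phi3VisionSpecificArgs` as `Box<dyn Any>` and the model downcasts it back.
#[cfg(test)]
mod args_tests {
    use super::*;

    #[test]
    fn phi3_args_downcast_roundtrip() {
        let args: Box<dyn Any> = Box::new(Phi3VisionSpecificArgs {
            image_sizes: Some(vec![(672, 1008)]),
        });
        let Phi3VisionSpecificArgs { image_sizes } =
            *args.downcast().expect("not Phi3VisionSpecificArgs");
        assert_eq!(image_sizes, Some(vec![(672, 1008)]));
    }
}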