#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

pub(crate) mod phi3_inputs_processor;

use candle_core::{
    shape::ShapeWithOneHole, DType, Device, IndexOp, Module, Result, Shape, Tensor, D,
};
use either::Either;
use mistralrs_quant::{QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder};
use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        self, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    ops::{BitWiseOp, NonZeroOp},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
    vision_models::clip::{ClipConfig, ClipVisionTransformer},
    AnyMoeConfig, AnyMoeExpertType,
};

use super::clip;

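/// Options for the image-embedding projection, mirroring the `embd_layer` section of the
/// Phi-3-vision configuration: HD-transform ordering, projection type, and whether learnable
/// separator tokens (`glb_GN`/`sub_GN`) are used.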
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct EmbedLayerConfig {
    pub hd_transform_order: Option<String>,
    pub projection_cls: Option<String>,
    pub use_hd_transform: Option<bool>,
    pub with_learnable_separator: Option<bool>,
}

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct ImageProcessorConfig {
    pub image_dim_out: usize,
    pub name: String,
    pub num_img_tokens: usize,
    pub layer_idx: Option<isize>,
    pub type_feature: Option<String>,
}

serde_default_fn!(bool, d_flash_attn, false);
serde_default_fn!(bool, word_emb_default, false);

#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    #[serde(default = "d_flash_attn")]
    pub use_flash_attn: bool,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    pub embd_layer: EmbedLayerConfig,
    pub img_processor: ImageProcessorConfig,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

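// Object-safe module trait that also exposes the device and dtype of its parameters; used to
// build the heterogeneous image-projection stack (`EmbeddingLayers`) below.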
trait ModuleWithMetadata: Module + Debug + Send + Sync {
    fn device(&self) -> Device;
    fn dtype(&self) -> DType;
}

#[derive(Debug)]
struct QuantMethodWrapper(Arc<dyn QuantMethod>);

impl Module for QuantMethodWrapper {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}

impl ModuleWithMetadata for QuantMethodWrapper {
    fn device(&self) -> Device {
        self.0.unquant_weight_bias().unwrap().0.device().clone()
    }
    fn dtype(&self) -> DType {
        self.0.unquant_weight_bias().unwrap().0.dtype()
    }
}

impl ModuleWithMetadata for candle_nn::Activation {
    fn device(&self) -> Device {
        unreachable!()
    }
    fn dtype(&self) -> DType {
        unreachable!()
    }
}

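// 6-dimensional reshape target with one inferred ("hole") dimension. The HD transform below
// reshapes sub-image features to (1, h, w, 12, 12, _), and the inferred final dimension is
// computed from the element count by `hole_size`.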
#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));

fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

impl ShapeWithOneHole for BigShapeWithOneHole {
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, d5, ()) = self.0;
        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
        Ok((d1, d2, d3, d4, d5, d).into())
    }
}

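// Self-attention with a fused QKV projection, grouped-query attention (num_kv_heads <=
// num_heads), Phi-style rotary embeddings, and an optional paged-attention backend.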
struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

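    // Projects to fused QKV, splits into Q/K/V, applies rotary embeddings, then runs either the
    // paged-attention kernel (when KV-cache metadata is provided) or the standard SDPA path.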
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

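// Gated MLP with a fused `gate_up_proj`: the first half of its output is the gate, the second
// half the up-projection; the product `up * act(gate)` is then passed through `down_proj`.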
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    params: Vec<usize>,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}

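// Standard pre-norm transformer block: RMSNorm -> attention -> residual add, then
// RMSNorm -> MLP -> residual add.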
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(
                &xs,
                attention_mask,
                seqlen_offsets,
                position_ids,
                kv_cache,
                metadata,
                flash_params,
            )
            .unwrap();
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

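// The inputs processor encodes image placeholder tokens as negative ids; any id in
// (-MAX_INPUT_ID, 0) marks a position that will be overwritten with projected image features.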
const MAX_INPUT_ID: f64 = 1e9;

#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);

impl Module for EmbeddingLayers {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in &self.0 {
            xs = layer.forward(&xs)?;
        }
        Ok(xs)
    }
}

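// Vision embedding module: runs pixel values through a CLIP vision transformer, optionally
// applies the Phi-3-vision HD transform, projects image features to the LM hidden size, and
// splices them into the text embeddings.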
#[derive(Debug)]
pub struct ImageEmbedding {
    wte: candle_nn::Embedding,
    image_dim_out: usize,
    num_img_tokens: usize,
    glb_gn: Option<Tensor>,
    sub_gn: Option<Tensor>,
    layers: EmbeddingLayers,
    type_feature: String,
    layer_idx: isize,
    image_processor: ClipVisionTransformer,
    hd_transform_order: String,
    use_hd_transform: bool,
    vocab_size: usize,
    tensors: Vec<(String, Tensor)>,
}

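// These dimensions correspond to a CLIP ViT-L/14 vision tower at 336x336 input resolution, the
// image encoder used by Phi-3-vision.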
pub(crate) const PHI3V_CLIP_CONFIG: ClipConfig = ClipConfig {
    hidden_act: clip::Activation::QuickGelu,
    hidden_size: 1024,
    image_size: 336,
    intermediate_size: 4096,
    num_attention_heads: 16,
    num_channels: 3,
    num_hidden_layers: 24,
    patch_size: 14,
};

impl ImageEmbedding {
    fn new(
        config: &Config,
        wte: candle_nn::Embedding,
        embed_config: &EmbedLayerConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let hidden_size = config.hidden_size;
        if config.img_processor.name != "clip_vision_model" {
            candle_core::bail!(
                "img_processor=`{}` not supported.",
                config.img_processor.name
            );
        }
        let image_dim_out = config.img_processor.image_dim_out;
        let num_img_tokens = config.img_processor.num_img_tokens;

        let image_processor =
            ClipVisionTransformer::new(vb.pp("img_processor.vision_model"), &PHI3V_CLIP_CONFIG)?;

        let use_hd_transform = embed_config.use_hd_transform.unwrap_or(false);
        let with_learnable_separator = embed_config.with_learnable_separator.unwrap_or(false);
        let hd_transform_order = embed_config
            .hd_transform_order
            .clone()
            .unwrap_or("glb_sub".to_string());
        assert_eq!(use_hd_transform, with_learnable_separator);
        let (glb_gn, sub_gn) = if with_learnable_separator {
            let glb_gn = vb.get((1, 1, image_dim_out * 4), "glb_GN")?;
            let sub_gn = vb.get((1, 1, 1, image_dim_out * 4), "sub_GN")?;
            (Some(glb_gn), Some(sub_gn))
        } else {
            (None, None)
        };

        let projection_cls = embed_config
            .projection_cls
            .clone()
            .unwrap_or("linear".to_string());

        let mut tensors = Vec::new();
        let layers: Vec<Box<dyn ModuleWithMetadata>> =
            match (projection_cls.as_str(), use_hd_transform) {
                ("linear", _) => {
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        hidden_size,
                        true,
                        &None,
                        vb.pp("img_projection"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.bias".to_string(), b));
                    }
                    vec![Box::new(QuantMethodWrapper(a))]
                }
                ("mlp", true) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out * 4,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                ("mlp", false) => {
                    let dim_proj = hidden_size;
                    let a = mistralrs_quant::linear_b(
                        image_dim_out,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.0"),
                    )?;
                    let (a_w, a_b) = a.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.0.weight".to_string(), a_w));
                    if let Some(b) = a_b {
                        tensors.push(("img_projection.0.bias".to_string(), b));
                    }
                    let b = mistralrs_quant::linear_b(
                        dim_proj,
                        dim_proj,
                        true,
                        &None,
                        vb.pp("img_projection.2"),
                    )?;
                    let (b_w, b_b) = b.unquant_weight_bias().unwrap();
                    tensors.push(("img_projection.2.weight".to_string(), b_w));
                    if let Some(b) = b_b {
                        tensors.push(("img_projection.2.bias".to_string(), b));
                    }
                    vec![
                        Box::new(QuantMethodWrapper(a)),
                        Box::new(candle_nn::Activation::Gelu),
                        Box::new(QuantMethodWrapper(b)),
                    ]
                }
                _ => {
                    candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
                }
            };

        let layer_idx = config.img_processor.layer_idx.unwrap_or(-2);
        let type_feature = config
            .img_processor
            .type_feature
            .clone()
            .unwrap_or("patch".to_string());

        Ok(Self {
            wte,
            image_dim_out,
            num_img_tokens,
            glb_gn,
            sub_gn,
            layer_idx,
            type_feature,
            image_processor,
            layers: EmbeddingLayers(layers),
            hd_transform_order,
            use_hd_transform,
            vocab_size: config.vocab_size,
            tensors,
        })
    }

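    // Runs CLIP and selects the hidden state at `layer_idx` (default -2, i.e. the penultimate
    // layer); for the "patch" feature type the CLS token is dropped.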
    fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let hidden_states = self
            .image_processor
            .forward_get_hidden_states(&pixel_values.to_dtype(self.wte.embeddings().dtype())?)?;
        let img_feature =
            hidden_states[(hidden_states.len() as isize + self.layer_idx) as usize].clone();
        if self.type_feature == "patch" {
            img_feature.i((.., 1..))
        } else if self.type_feature == "cls_patch" {
            Ok(img_feature)
        } else {
            candle_core::bail!("Unsupported image feature type {}", self.type_feature)
        }
    }

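    // Embeds the input ids and, where placeholder (negative) ids occur, replaces them with
    // projected image features. With the HD transform, each image contributes a global 336x336
    // crop plus h*w 336x336 sub-crops; the 24x24 patch features of each crop are rearranged into
    // 12x12 blocks of 4*C channels, rows are separated by `sub_GN`, the global and sub features
    // are separated by `glb_GN`, and everything is concatenated in `hd_transform_order` before
    // projection.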
    #[allow(non_snake_case)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: &Tensor,
        image_sizes: Option<Vec<(usize, usize)>>,
    ) -> Result<Tensor> {
        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;

        let input_ids_lt = input_ids.lt(0.0f64)?;
        let input_ids_gt = input_ids.gt(-MAX_INPUT_ID)?;
        let positions = input_ids_lt.bitwise_and(&input_ids_gt)?.nonzero()?;
        let target_dev = self.layers.0[0].device();
        let target_dtype = self.layers.0[0].dtype();

        let mut select = false;
        let mut hd_transform = None;
        let mut image_set_tensor = None;
        if positions.dim(0)? > 0 {
            select = true;
            if self.use_hd_transform && image_sizes.is_some() {
                assert_eq!(pixel_values.dims().len(), 5);
                let bs = pixel_values.dim(0)?;
                let img_features = self.get_image_features(&pixel_values.flatten(0, 1)?)?;
                let base_feat_dim = (img_features.dims()[1] as f32).sqrt() as usize;
                assert_eq!(base_feat_dim, 24);

                let img_features =
                    img_features.reshape((bs, (), base_feat_dim.pow(2), self.image_dim_out))?;
                let C = self.image_dim_out;
                let H = base_feat_dim;

                let mut output_imgs = Vec::new();
                let mut output_len = Vec::new();
                for bs_ in 0..bs {
                    let (h, w) = image_sizes.as_ref().unwrap()[bs_];
                    let h = h / 336;
                    let w = w / 336;
                    let B_ = h * w;

                    let global_img_feature = img_features.i((bs_, ..1))?;

                    let glb_img = global_img_feature
                        .reshape((1, H, H, C))?
                        .reshape((1, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, H / 2, H / 2, 4 * C))?
                        .contiguous()?;
                    let temp_glbl_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, H / 2, 1, 1))?;

                    let glb_img =
                        Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((1, (), 4 * C))?;

                    let sub_img = img_features.i((bs_, 1..))?;

                    let sub_img = sub_img.i(..B_)?;

                    let sub_img = sub_img
                        .reshape((B_, H, H, C))?
                        .reshape((B_, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((B_, (), 4 * C))?
                        .contiguous()?;
                    let sub_img = sub_img
                        .reshape(BigShapeWithOneHole((1usize, h, w, 12usize, 12usize, ())))?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, h * 12, w * 12, 4 * C))?;
                    let temp_sub_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, h * 12, 1, 1))?;

                    let sub_img =
                        Tensor::cat(&[sub_img, temp_sub_gn], 2)?.reshape((1, (), 4 * C))?;

                    match self.hd_transform_order.as_str() {
                        "glb_sub" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    glb_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    sub_img,
                                ],
                                1,
                            )?);
                        }
                        "sub_glb" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    sub_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    glb_img,
                                ],
                                1,
                            )?);
                        }
                        other => {
                            candle_core::bail!("Invalid hd_transform_order=`{other}`");
                        }
                    }

                    let temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12;
                    assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
                    output_len.push(temp_len);
                }

                hd_transform = Some(output_len);
                let mut image_set_tensor_inner = Vec::new();
                for img in output_imgs {
                    let layerout = self
                        .layers
                        .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
                    image_set_tensor_inner.push(layerout);
                }
                image_set_tensor = Some(Either::Left(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 4 {
                let tt = self
                    .get_image_features(pixel_values)?
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 3 {
                let tt = pixel_values
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else {
                unreachable!()
            }
        }

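        // Clamp away the negative placeholder ids before the token-embedding lookup, then
        // overwrite the placeholder positions with the projected image features.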
        let input_ids = input_ids.clamp(0.0, self.vocab_size as f64)?;
        let mut hidden_states = self.wte.forward(&input_ids)?;
        if select {
            match (hd_transform, image_set_tensor) {
                (Some(output_lens), Some(Either::Left(image_set_tensors))) => {
                    let mut idx = 0;
                    for (i, cnt) in output_lens.into_iter().enumerate() {
                        let img_set_tensor = image_set_tensors[i]
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                (None, Some(Either::Right(image_set_tensor))) => {
                    let mut idx = 0;
                    for i in 0..pixel_values.dim(0)? {
                        let cnt = self.num_img_tokens;
                        let img_set_tensor = image_set_tensor
                            .i(i * cnt..(i + 1) * cnt)?
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                _ => unreachable!(),
            }
        }

        Ok(hidden_states)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        if let Some(glb_gn) = self.glb_gn.clone() {
            uvb.add_tensor("glb_GN", glb_gn);
        }
        if let Some(sub_gn) = self.sub_gn.clone() {
            uvb.add_tensor("sub_GN", sub_gn);
        }
        uvb.extend(self.tensors.clone());
        uvb.pp("img_processor.vision_model")
            .extend(self.image_processor.residual_tensors());

        uvb.to_safetensors()
    }
}

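// Full Phi-3-vision model: token/vision embeddings, the decoder stack, final RMSNorm, and the
// (optionally tied) LM head.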
pub struct Model {
    vision_embed_tokens: ImageEmbedding,
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let vision_embed_tokens = ImageEmbedding::new(
            cfg,
            embed_tokens.clone(),
            &cfg.embd_layer,
            mapper.set_nm_device(vb_m.pp("vision_embed_tokens"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        Ok(Self {
            vision_embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

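    // Forward pass: uses the vision embedding path when `pixel_values` are present (prompt
    // processing), otherwise a plain token-embedding lookup, then runs the decoder stack with a
    // sliding-window causal mask and extracts logits for the requested context lengths.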
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(usize, usize)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = if let Some(ref pixel_values) = pixel_values {
            self.vision_embed_tokens
                .forward(input_ids, pixel_values, image_sizes)?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);
        uvb_m
            .pp("vision_embed_tokens")
            .extend(self.vision_embed_tokens.residual_tensors());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

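// Extra per-request arguments supplied by the Phi-3-vision inputs processor: the (height, width)
// of each preprocessed image, which drives the HD-transform crop layout.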
#[derive(Default)]
pub(crate) struct Phi3VisionSpecificArgs {
    pub image_sizes: Option<Vec<(usize, usize)>>,
}

impl VisionModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let Phi3VisionSpecificArgs { image_sizes } = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Phi3VisionSpecificArgs`");
        self.forward(
            input_ids,
            pixel_values,
            seqlen_offsets,
            &position_ids,
            context_lens,
            image_sizes,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        Box::new(Phi3VisionSpecificArgs::default())
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        row.push(Box::new(Mlp::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp),
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_up_proj_delta =
                            if target_modules.contains(&"gate_up_proj".to_string()) {
                                Some(get_delta_from_lora_ab!(
                                    vb_mlp,
                                    rank,
                                    alpha,
                                    (hidden_size, 2 * intermediate_size),
                                    "gate_up_proj"
                                ))
                            } else {
                                None
                            };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}