1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3pub(crate) mod phi3_inputs_processor;
4
5use candle_core::{
8 shape::ShapeWithOneHole, DType, Device, IndexOp, Module, Result, Shape, Tensor, D,
9};
10use either::Either;
11use mistralrs_quant::{
12 BitWiseOp, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder,
13};
14use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
15
16use crate::{
17 amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
18 attention::SdpaParams,
19 device_map::DeviceMapper,
20 get_delta_from_lora_ab,
21 layers::{
22 self, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
23 PhiRotaryEmbedding, RmsNorm, Sdpa,
24 },
25 layers_masker::PastKvLenCache,
26 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
27 pipeline::{
28 extract_logits,
29 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
30 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
31 },
32 serde_default_fn,
33 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
34 vision_models::clip::{ClipConfig, ClipVisionTransformer},
35 AnyMoeConfig, AnyMoeExpertType,
36};
37
38use super::clip;
39
/// Configuration for the image-embedding projection layer, deserialized from
/// the model's `embd_layer` config section. All fields are optional; defaults
/// are applied in `ImageEmbedding::new`.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct EmbedLayerConfig {
    // Ordering of global vs. sub-image features ("glb_sub" or "sub_glb").
    pub hd_transform_order: Option<String>,
    // Projection head type: "linear" or "mlp".
    pub projection_cls: Option<String>,
    // Whether the high-definition (multi-crop) transform is applied.
    pub use_hd_transform: Option<bool>,
    // Whether learnable separator tensors (`glb_GN`/`sub_GN`) are loaded.
    pub with_learnable_separator: Option<bool>,
}
47
/// Configuration of the vision tower (`img_processor` config section).
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct ImageProcessorConfig {
    // Feature dimension produced by the vision tower.
    pub image_dim_out: usize,
    // Vision tower identifier; only "clip_vision_model" is supported here.
    pub name: String,
    // Number of image tokens emitted per image in the non-HD path.
    pub num_img_tokens: usize,
    // Which hidden-states layer to take features from (negative = from the end; default -2).
    pub layer_idx: Option<isize>,
    // Feature selection: "patch" (drop CLS) or "cls_patch" (keep all); default "patch".
    pub type_feature: Option<String>,
}
56
// serde default helper: `tie_word_embeddings` defaults to `false` when absent.
serde_default_fn!(bool, word_emb_default, false);
58
/// Top-level Phi-3 Vision model configuration (deserialized from config.json).
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    // Long-context RoPE scaling; forwarded into `PhiRopeConfig`.
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    // Vision projection settings.
    pub embd_layer: EmbedLayerConfig,
    // Vision tower settings.
    pub img_processor: ImageProcessorConfig,
    #[serde(alias = "quantization")]
    pub quantization_config: Option<QuantizedConfig>,
    // When true, `lm_head` reuses the embedding matrix instead of its own weights.
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}
83
/// Extract the RoPE-relevant subset of the model config.
impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}
96
impl Config {
    /// Per-head dimension: `hidden_size / num_attention_heads`.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}
102
/// A `Module` that can also report the device and dtype its weights live on,
/// used by `ImageEmbedding` to move/convert image features before projection.
trait ModuleWithMetadata: Module + Debug + Send + Sync {
    fn device(&self) -> Device;
    fn dtype(&self) -> DType;
}
107
/// Newtype adapting an `Arc<dyn QuantMethod>` linear layer to `Module`
/// (and `ModuleWithMetadata`) so it can sit in the projection pipeline.
#[derive(Debug)]
struct QuantMethodWrapper(Arc<dyn QuantMethod>);
110
impl Module for QuantMethodWrapper {
    // Delegate straight to the wrapped quantized linear layer.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}
116
impl ModuleWithMetadata for QuantMethodWrapper {
    // NOTE(review): `unwrap()` assumes the wrapped layer can expose unquantized
    // weights; the projection layers built in `ImageEmbedding::new` are created
    // with `&None` quantization config, so this holds for them.
    fn device(&self) -> Device {
        self.0.unquant_weight_bias().unwrap().0.device().clone()
    }
    fn dtype(&self) -> DType {
        self.0.unquant_weight_bias().unwrap().0.dtype()
    }
}
125
impl ModuleWithMetadata for candle_nn::Activation {
    // Activations have no weights; callers only query device/dtype on the
    // first pipeline element, which is always a linear layer (see
    // `ImageEmbedding::forward`), so these are never reached.
    fn device(&self) -> Device {
        unreachable!()
    }
    fn dtype(&self) -> DType {
        unreachable!()
    }
}
134
/// A 6-dimensional shape whose last dimension is inferred (`()` is the hole);
/// candle's built-in `ShapeWithOneHole` impls stop short of 6 dims.
#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));
137
138fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
139 if prod_d == 0 {
140 candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
141 }
142 if el_count % prod_d != 0 {
143 candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
144 }
145 Ok(el_count / prod_d)
146}
147
impl ShapeWithOneHole for BigShapeWithOneHole {
    /// Resolve the trailing hole so the full 6-D shape has `el_count` elements.
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, d5, ()) = self.0;
        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
        Ok((d1, d2, d3, d4, d5, d).into())
    }
}
155
/// Multi-head self-attention with a fused QKV projection, Phi-style RoPE, and
/// optional paged-attention backend.
struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    // Some(_) when running with the paged-attention KV cache backend.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
168
impl Attention {
    /// Load the attention weights: a fused QKV projection plus the output
    /// projection, both honoring the model's quantization config.
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        // Fused output width: Q (num_heads) then K and V (num_kv_heads each).
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                // Grouped-query attention: how many query heads share a KV head.
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// One attention pass over `xs` of shape (batch, seq, hidden).
    ///
    /// `metadata` carries the paged-attention KV cache; when absent and
    /// `paged_attn` is set, a dummy metadata is used (prompt chunk without an
    /// allocated cache). `kv_cache` is the non-paged fallback cache.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // Quantized layers may require a specific activation dtype; convert in,
        // then restore the original dtype after the matmul.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        // Split the fused QKV output along the last dim: [Q | K | V].
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        // Reshape to (batch, heads, seq, head_dim); the q_len == 1 decode case
        // can reshape directly without a transpose.
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // Apply rotary position embeddings to Q and K.
        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No real KV cache provided (e.g. prompt chunk); run paged
                    // attention against dummy metadata with no cache tensors.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check: this path requires an attention mask.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Eager path: append to the rotating KV cache and run SDPA.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // With a mask the output is (b, heads, seq, head_dim) and needs the
        // transpose back; otherwise it is already laid out per-token.
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
327
/// Gated MLP with a fused gate+up projection (SwiGLU-style: down(up * act(gate))).
#[derive(Clone)]
struct Mlp {
    // Fused projection producing [gate | up], each of width `i_size`.
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    // [hidden_size, intermediate_size], reported via `MlpLayer::get_params`.
    params: Vec<usize>,
}
336
337impl Mlp {
338 fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
339 let hidden_size = cfg.hidden_size;
340 let i_size = cfg.intermediate_size;
341
342 let gate_up_proj = mistralrs_quant::linear_no_bias(
344 hidden_size,
345 2 * i_size,
346 &cfg.quantization_config,
347 vb.pp("gate_up_proj"),
348 )?;
349
350 let down_proj = mistralrs_quant::linear_no_bias(
351 i_size,
352 hidden_size,
353 &cfg.quantization_config,
354 vb.pp("down_proj"),
355 )?;
356
357 Ok(Self {
358 gate_up_proj,
359 down_proj,
360 act_fn: cfg.hidden_act,
361 i_size,
362 params: vec![hidden_size, i_size],
363 })
364 }
365}
366
// Marker impl: allows `Mlp` to act as a trainable AnyMoE expert.
impl AnyMoeTrainableLayer for Mlp {}
368
impl MlpLayer for Mlp {
    /// down(up * act(gate)), with dtype round-tripping for quantized layers.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        // Fused output: first `i_size` columns are the gate, next `i_size` the up-projection.
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    // Layers eligible for in-situ quantization.
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    /// Build a new MLP with LoRA-style weight deltas applied; `deltas[0]` is
    /// for the gate+up projection, `deltas[1]` for the down projection.
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}
424
/// One pre-norm transformer decoder layer: RmsNorm -> attention -> residual,
/// then RmsNorm -> MLP -> residual.
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}
431
432impl DecoderLayer {
433 fn new(
434 rotary_emb: Arc<PhiRotaryEmbedding>,
435 cfg: &Config,
436 vb: ShardedVarBuilder,
437 mapper: &dyn DeviceMapper,
438 layer_idx: usize,
439 loading_isq: bool,
440 paged_attn: Option<PagedAttention>,
441 ) -> Result<Self> {
442 let self_attn = Attention::new(
443 rotary_emb,
444 cfg,
445 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
446 paged_attn,
447 )?;
448 let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
449 let input_layernorm = RmsNorm::new(
450 cfg.hidden_size,
451 cfg.rms_norm_eps,
452 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
453 )?;
454 let post_attention_layernorm = RmsNorm::new(
455 cfg.hidden_size,
456 cfg.rms_norm_eps,
457 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
458 )?;
459 Ok(Self {
460 self_attn,
461 mlp: Box::new(mlp),
462 input_layernorm,
463 post_attention_layernorm,
464 })
465 }
466
467 #[allow(clippy::too_many_arguments)]
468 fn forward(
469 &self,
470 xs: &Tensor,
471 attention_mask: Option<&Tensor>,
472 seqlen_offsets: &[usize],
473 position_ids: &[usize],
474 kv_cache: &mut KvCache,
475 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
476 flash_params: &FlashParams,
477 ) -> Result<Tensor> {
478 let residual = xs;
479 let xs = self.input_layernorm.forward(xs)?;
480 let xs = self
481 .self_attn
482 .forward(
483 &xs,
484 attention_mask,
485 seqlen_offsets,
486 position_ids,
487 kv_cache,
488 metadata,
489 flash_params,
490 )
491 .unwrap();
492 let xs = (xs + residual)?;
493 let residual = &xs;
494 let xs = self
495 .mlp
496 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
497 residual + xs
498 }
499}
500
// Bound for image-placeholder token ids: placeholders are negative input ids
// in (-MAX_INPUT_ID, 0), detected in `ImageEmbedding::forward`.
const MAX_INPUT_ID: f64 = 1e9;
506
/// Sequential pipeline of projection layers (linear / activation / linear)
/// applied to image features.
#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);
509
510impl Module for EmbeddingLayers {
511 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
512 let mut xs = xs.clone();
513 for layer in &self.0 {
514 xs = layer.forward(&xs)?;
515 }
516 Ok(xs)
517 }
518}
519
/// Phi-3 Vision image embedding: runs pixels through a CLIP vision tower,
/// optionally applies the HD (multi-crop) transform, projects features into
/// the LM hidden size, and splices them into the token embedding sequence.
#[derive(Debug)]
pub struct ImageEmbedding {
    // Word-token embedding table shared with the language model.
    wte: candle_nn::Embedding,
    image_dim_out: usize,
    num_img_tokens: usize,
    // Learnable separators, present only when the HD transform is enabled.
    glb_gn: Option<Tensor>,
    sub_gn: Option<Tensor>,
    // Projection pipeline from image features to the LM hidden size.
    layers: EmbeddingLayers,
    type_feature: String,
    // Hidden-states index to read from the CLIP tower (negative = from the end).
    layer_idx: isize,
    image_processor: ClipVisionTransformer,
    hd_transform_order: String,
    use_hd_transform: bool,
    vocab_size: usize,
    // Projection weights retained for `residual_tensors` serialization.
    tensors: Vec<(String, Tensor)>,
}
536
// Fixed CLIP ViT-L/14-336 configuration used by Phi-3 Vision
// (336px images, 14px patches => 24x24 patch grid).
pub(crate) const PHI3V_CLIP_CONFIG: ClipConfig = ClipConfig {
    hidden_act: clip::Activation::QuickGelu,
    hidden_size: 1024,
    image_size: 336,
    intermediate_size: 4096,
    num_attention_heads: 16,
    num_channels: 3,
    num_hidden_layers: 24,
    patch_size: 14,
};
547
548impl ImageEmbedding {
549 fn new(
550 config: &Config,
551 wte: candle_nn::Embedding,
552 embed_config: &EmbedLayerConfig,
553 vb: ShardedVarBuilder,
554 ) -> Result<Self> {
555 let hidden_size = config.hidden_size;
556 if config.img_processor.name != "clip_vision_model" {
557 candle_core::bail!(
558 "img_processor=`{}` nor supported.",
559 config.img_processor.name
560 );
561 }
562 let image_dim_out = config.img_processor.image_dim_out;
563 let num_img_tokens = config.img_processor.num_img_tokens;
564
565 let image_processor =
567 ClipVisionTransformer::new(vb.pp("img_processor.vision_model"), &PHI3V_CLIP_CONFIG)?;
568
569 let use_hd_transform = embed_config.use_hd_transform.unwrap_or(false);
571 let with_learnable_separator = embed_config.with_learnable_separator.unwrap_or(false);
572 let hd_transform_order = embed_config
573 .hd_transform_order
574 .clone()
575 .unwrap_or("glb_sub".to_string());
576 assert_eq!(use_hd_transform, with_learnable_separator);
577 let (glb_gn, sub_gn) = if with_learnable_separator {
578 let glb_gn = vb.get((1, 1, image_dim_out * 4), "glb_GN")?;
579 let sub_gn = vb.get((1, 1, 1, image_dim_out * 4), "sub_GN")?;
580 (Some(glb_gn), Some(sub_gn))
581 } else {
582 (None, None)
583 };
584
585 let projection_cls = embed_config
587 .projection_cls
588 .clone()
589 .unwrap_or("linear".to_string());
590
591 let mut tensors = Vec::new();
592 let layers: Vec<Box<dyn ModuleWithMetadata>> =
593 match (projection_cls.as_str(), use_hd_transform) {
594 ("linear", _) => {
595 let a = mistralrs_quant::linear_b(
596 image_dim_out,
597 hidden_size,
598 true,
599 &None,
600 vb.pp("img_projection"),
601 )?;
602 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
603 tensors.push(("img_projection.weight".to_string(), a_w));
604 if let Some(b) = a_b {
605 tensors.push(("img_projection.bias".to_string(), b));
606 }
607 vec![Box::new(QuantMethodWrapper(a))]
608 }
609 ("mlp", true) => {
610 let dim_proj = hidden_size;
611 let a = mistralrs_quant::linear_b(
612 image_dim_out * 4,
613 dim_proj,
614 true,
615 &None,
616 vb.pp("img_projection.0"),
617 )?;
618 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
619 tensors.push(("img_projection.0.weight".to_string(), a_w));
620 if let Some(b) = a_b {
621 tensors.push(("img_projection.0.bias".to_string(), b));
622 }
623 let b = mistralrs_quant::linear_b(
624 dim_proj,
625 dim_proj,
626 true,
627 &None,
628 vb.pp("img_projection.2"),
629 )?;
630 let (b_w, b_b) = b.unquant_weight_bias().unwrap();
631 tensors.push(("img_projection.2.weight".to_string(), b_w));
632 if let Some(b) = b_b {
633 tensors.push(("img_projection.2.bias".to_string(), b));
634 }
635 vec![
636 Box::new(QuantMethodWrapper(a)),
637 Box::new(candle_nn::Activation::Gelu),
638 Box::new(QuantMethodWrapper(b)),
639 ]
640 }
641 ("mlp", false) => {
642 let dim_proj = hidden_size;
643 let a = mistralrs_quant::linear_b(
644 image_dim_out,
645 dim_proj,
646 true,
647 &None,
648 vb.pp("img_projection.0"),
649 )?;
650 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
651 tensors.push(("img_projection.0.weight".to_string(), a_w));
652 if let Some(b) = a_b {
653 tensors.push(("img_projection.0.bias".to_string(), b));
654 }
655 let b = mistralrs_quant::linear_b(
656 dim_proj,
657 dim_proj,
658 true,
659 &None,
660 vb.pp("img_projection.2"),
661 )?;
662 let (b_w, b_b) = b.unquant_weight_bias().unwrap();
663 tensors.push(("img_projection.2.weight".to_string(), b_w));
664 if let Some(b) = b_b {
665 tensors.push(("img_projection.2.bias".to_string(), b));
666 }
667 vec![
668 Box::new(QuantMethodWrapper(a)),
669 Box::new(candle_nn::Activation::Gelu),
670 Box::new(QuantMethodWrapper(b)),
671 ]
672 }
673 _ => {
674 candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
675 }
676 };
677
678 let layer_idx = config.img_processor.layer_idx.unwrap_or(-2);
679 let type_feature = config
680 .img_processor
681 .type_feature
682 .clone()
683 .unwrap_or("patch".to_string());
684
685 Ok(Self {
686 wte,
687 image_dim_out,
688 num_img_tokens,
689 glb_gn,
690 sub_gn,
691 layer_idx,
692 type_feature,
693 image_processor,
694 layers: EmbeddingLayers(layers),
695 hd_transform_order,
696 use_hd_transform,
697 vocab_size: config.vocab_size,
698 tensors,
699 })
700 }
701
702 fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
703 let hidden_states = self
704 .image_processor
705 .forward_get_hidden_states(&pixel_values.to_dtype(self.wte.embeddings().dtype())?)?;
706 let img_feature =
707 hidden_states[(hidden_states.len() as isize + self.layer_idx) as usize].clone();
708 if self.type_feature == "patch" {
709 img_feature.i((.., 1..))
710 } else if self.type_feature == "cls_patch" {
711 Ok(img_feature)
712 } else {
713 candle_core::bail!("Unsupported image feature type {}", self.type_feature)
714 }
715 }
716
717 #[allow(non_snake_case)]
718 fn forward(
719 &self,
720 input_ids: &Tensor,
721 pixel_values: &Tensor,
722 image_sizes: Option<Vec<(usize, usize)>>,
723 ) -> Result<Tensor> {
724 let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;
725
726 let input_ids_lt = input_ids.lt(0.0f64)?;
727 let input_ids_gt = input_ids.gt(-MAX_INPUT_ID)?;
728 let positions = input_ids_lt.bitwise_and(&input_ids_gt)?.nonzero()?;
730 let target_dev = self.layers.0[0].device();
731 let target_dtype = self.layers.0[0].dtype();
732
733 let mut select = false;
734 let mut hd_transform = None;
736 let mut image_set_tensor = None;
737 if positions.dim(0)? > 0 {
738 select = true;
739 if self.use_hd_transform {
741 if let Some(image_sizes_ref) = image_sizes.as_ref() {
742 assert_eq!(pixel_values.dims().len(), 5);
743 let bs = pixel_values.dim(0)?;
744 let img_features = self.get_image_features(&pixel_values.flatten(0, 1)?)?;
745 let base_feat_dim = (img_features.dims()[1] as f32).sqrt() as usize;
746 assert_eq!(base_feat_dim, 24);
747
748 let img_features =
750 img_features.reshape((bs, (), base_feat_dim.pow(2), self.image_dim_out))?;
751 let C = self.image_dim_out;
752 let H = base_feat_dim;
753
754 let mut output_imgs = Vec::new();
755 let mut output_len = Vec::new();
756 for (bs_, &(h, w)) in image_sizes_ref.iter().enumerate().take(bs) {
757 let h = h / 336;
758 let w = w / 336;
759 let B_ = h * w;
760
761 let global_img_feature = img_features.i((bs_, ..1))?;
763
764 let glb_img = global_img_feature
766 .reshape((1, H, H, C))?
767 .reshape((1, H / 2, 2, H / 2, 2, C))?
768 .contiguous()?
769 .permute((0, 1, 3, 2, 4, 5))?
770 .reshape((1, H / 2, H / 2, 4 * C))?
771 .contiguous()?;
772 let temp_glbl_gn = self
773 .sub_gn
774 .as_ref()
775 .expect("Need `sub_gn` if `use_hd_transform`")
776 .repeat((1, H / 2, 1, 1))?;
777
778 let glb_img =
780 Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((1, (), 4 * C))?;
781
782 let sub_img = img_features.i((bs_, 1..))?;
784
785 let sub_img = sub_img.i(..B_)?;
788
789 let sub_img = sub_img
791 .reshape((B_, H, H, C))?
792 .reshape((B_, H / 2, 2, H / 2, 2, C))?
793 .contiguous()?
794 .permute((0, 1, 3, 2, 4, 5))?
795 .reshape((B_, (), 4 * C))?
796 .contiguous()?;
797 let sub_img = sub_img
798 .reshape(BigShapeWithOneHole((1usize, h, w, 12usize, 12usize, ())))?
799 .permute((0, 1, 3, 2, 4, 5))?
800 .reshape((1, h * 12, w * 12, 4 * C))?;
801 let temp_sub_gn = self
802 .sub_gn
803 .as_ref()
804 .expect("Need `sub_gn` if `use_hd_transform`")
805 .repeat((1, h * 12, 1, 1))?;
806
807 let sub_img =
808 Tensor::cat(&[sub_img, temp_sub_gn], 2)?.reshape((1, (), 4 * C))?;
809
810 match self.hd_transform_order.as_str() {
813 "glb_sub" => {
814 output_imgs.push(Tensor::cat(
815 &[
816 glb_img,
817 self.glb_gn
818 .as_ref()
819 .expect("Need `glb_gn` if `use_hd_transform`")
820 .clone(),
821 sub_img,
822 ],
823 1,
824 )?);
825 }
826 "sub_glb" => {
827 output_imgs.push(Tensor::cat(
828 &[
829 sub_img,
830 self.glb_gn
831 .as_ref()
832 .expect("Need `glb_gn` if `use_hd_transform`")
833 .clone(),
834 glb_img,
835 ],
836 1,
837 )?);
838 }
839 other => {
840 candle_core::bail!("Invalid hd_transform_order=`{other}`");
841 }
842 }
843
844 let temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12;
845 assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
846 output_len.push(temp_len);
847 }
848
849 hd_transform = Some(output_len);
850 let mut image_set_tensor_inner = Vec::new();
851 for img in output_imgs {
852 let layerout = self
853 .layers
854 .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
855 image_set_tensor_inner.push(layerout);
856 }
857 image_set_tensor = Some(Either::Left(image_set_tensor_inner));
858 }
859 } else if pixel_values.dims().len() == 4 {
860 let tt = self
861 .get_image_features(pixel_values)?
862 .to_device(&target_dev)?
863 .to_dtype(target_dtype)?
864 .reshape(((), self.image_dim_out))?;
865 let image_set_tensor_inner = self.layers.forward(&tt)?;
866 image_set_tensor = Some(Either::Right(image_set_tensor_inner));
867 } else if pixel_values.dims().len() == 3 {
868 let tt = pixel_values
869 .to_device(&target_dev)?
870 .to_dtype(target_dtype)?
871 .reshape(((), self.image_dim_out))?;
872 let image_set_tensor_inner = self.layers.forward(&tt)?;
873 image_set_tensor = Some(Either::Right(image_set_tensor_inner));
874 } else {
875 unreachable!()
876 }
877 }
878
879 let input_ids = input_ids.clamp(0.0, self.vocab_size as f64)?;
880 let mut hidden_states = self.wte.forward(&input_ids)?;
881 if select {
882 match (hd_transform, image_set_tensor) {
883 (Some(output_lens), Some(Either::Left(image_set_tensors))) => {
884 let mut idx = 0;
885 for (i, cnt) in output_lens.into_iter().enumerate() {
886 let img_set_tensor = image_set_tensors[i]
887 .to_device(&target_dev)?
888 .to_dtype(target_dtype)?;
889 let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
891 let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
892 hidden_states = hidden_states.slice_assign(
893 &[p_0..p_0 + 1, p_1..p_1 + cnt, 0..img_set_tensor.dims()[2]],
894 &img_set_tensor,
895 )?;
896 idx += cnt;
897 }
898 }
899 (None, Some(Either::Right(image_set_tensor))) => {
900 let mut idx = 0;
901 for i in 0..pixel_values.dim(0)? {
904 let cnt = self.num_img_tokens;
905 let img_set_tensor = image_set_tensor
906 .i(i * cnt..(i + 1) * cnt)?
907 .to_device(&target_dev)?
908 .to_dtype(target_dtype)?;
909 let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
910 let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
911 hidden_states = hidden_states.slice_assign(
913 &[p_0..p_0 + 1, p_1..p_1 + cnt, 0..img_set_tensor.dims()[2]],
914 &img_set_tensor,
915 )?;
916 idx += cnt;
917 }
918 }
919 _ => unreachable!(),
920 }
921 }
922
923 Ok(hidden_states)
924 }
925
926 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
927 let uvb = UnVarBuilder::new();
928
929 if let Some(glb_gn) = self.glb_gn.clone() {
930 uvb.add_tensor("glb_GN", glb_gn);
931 }
932 if let Some(sub_gn) = self.sub_gn.clone() {
933 uvb.add_tensor("sub_GN", sub_gn);
934 }
935 uvb.extend(self.tensors.clone());
936 uvb.pp("img_processor.vision_model")
937 .extend(self.image_processor.residual_tensors());
938
939 uvb.to_safetensors()
940 }
941}
942
/// Phi-3 Vision model: vision embedding + decoder stack + LM head.
pub struct Model {
    vision_embed_tokens: ImageEmbedding,
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    // Maps layers/activations across devices for multi-device inference.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}
958
impl Model {
    /// Load the full model: embeddings, vision embedding, decoder layers
    /// (with per-device RoPE instances), final norm, and the LM head
    /// (tied to the embedding matrix when `tie_word_embeddings`).
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let vision_embed_tokens = ImageEmbedding::new(
            cfg,
            embed_tokens.clone(),
            &cfg.embd_layer,
            mapper.set_nm_device(vb_m.pp("vision_embed_tokens"), false),
        )?;
        let vb_l = vb_m.pp("layers");
        // One RoPE instance per distinct device so each layer reads rotary
        // tables local to its device.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let layers = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            DecoderLayer::new(
                rotary_emb,
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            // Tied weights: reuse the embedding matrix as the output projection.
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        Ok(Self {
            vision_embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Heads are split across the tensor-parallel world size.
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    /// Full forward pass; with `pixel_values` the vision embedding splices
    /// image features into the token embeddings, otherwise plain text embedding.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(usize, usize)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = if let Some(ref pixel_values) = pixel_values {
            self.vision_embed_tokens
                .forward(input_ids, pixel_values, image_sizes)?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        // With paged attention, past length comes from seqlen offsets; the
        // normal cache tracks it itself.
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // Only the first prompt chunk needs an explicit causal mask.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to the device this layer is mapped to.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
1135
1136impl IsqModel for Model {
1137 fn get_layers(
1138 &mut self,
1139 ) -> (
1140 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
1141 &dyn DeviceMapper,
1142 ) {
1143 let mut tensors = Vec::new();
1144 tensors.push((&mut self.lm_head, None));
1145 for (i, layer) in self.layers.iter_mut().enumerate() {
1146 tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
1147 tensors.push((&mut layer.self_attn.o_proj, Some(i)));
1148 tensors.extend(
1149 layer
1150 .mlp
1151 .get_isq_layers()
1152 .into_iter()
1153 .map(|m| (m, Some(i)))
1154 .collect::<Vec<_>>(),
1155 );
1156 }
1157 (tensors, &*self.mapper)
1158 }
1159
1160 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
1161 let uvb = UnVarBuilder::new();
1162
1163 let uvb_m = uvb.pp("model");
1164 uvb_m.pp("embed_tokens").add(&self.embed_tokens);
1165 uvb_m.pp("norm").add(&self.norm);
1166 uvb_m
1167 .pp("vision_embed_tokens")
1168 .extend(self.vision_embed_tokens.residual_tensors());
1169
1170 for (layer_idx, layer) in self.layers.iter().enumerate() {
1171 let uvb_l = uvb_m.pp("layers").pp(layer_idx);
1172 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
1173 uvb_l
1174 .pp("post_attention_layernorm")
1175 .add(&layer.post_attention_layernorm);
1176 }
1177
1178 uvb.to_safetensors()
1179 }
1180}
1181
/// Vision-specific inputs passed (type-erased through `Box<dyn Any>`) into
/// [`VisionModel::forward`] for Phi-3-Vision. The default (`None`) means the
/// batch contains no images.
#[derive(Default)]
pub(crate) struct Phi3VisionSpecificArgs {
    // Per-image sizes for the batch; presumably (height, width) pairs from
    // the inputs processor — TODO(review): confirm the axis order there.
    pub image_sizes: Option<Vec<(usize, usize)>>,
}
1186
1187impl VisionModel for Model {
1188 fn forward(
1189 &self,
1190 input_ids: &Tensor,
1191 pixel_values: Option<Tensor>,
1192 seqlen_offsets: &[usize],
1193 context_lens: Vec<(usize, usize)>,
1194 position_ids: Vec<usize>,
1195 model_specific_args: Box<dyn Any>,
1196 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1197 flash_params: &FlashParams,
1198 ) -> Result<Tensor> {
1199 let Phi3VisionSpecificArgs { image_sizes } = *model_specific_args
1200 .downcast()
1201 .expect("Cannot downcast into `Phi3VisionSpecificArgs`");
1202 self.forward(
1203 input_ids,
1204 pixel_values,
1205 seqlen_offsets,
1206 &position_ids,
1207 context_lens,
1208 image_sizes,
1209 metadata,
1210 flash_params,
1211 )
1212 }
1213 fn cache(&self) -> &EitherCache {
1214 &self.cache
1215 }
1216 fn cache_mut(&mut self) -> &mut EitherCache {
1217 &mut self.cache
1218 }
1219 fn device(&self) -> &Device {
1220 &self.device
1221 }
1222 fn max_seq_len(&self) -> usize {
1223 self.max_seq_len
1224 }
1225 fn config(&self) -> &ModelConfigMetadata {
1226 &self.cfg
1227 }
1228 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
1229 Box::new(Phi3VisionSpecificArgs::default())
1230 }
1231}
1232
1233impl AnyMoeBaseModelMixin for Model {
1234 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
1235 let mut mlps = Vec::new();
1236 for layer in &self.layers {
1237 mlps.push(&*layer.mlp);
1238 }
1239 mlps
1240 }
1241 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
1242 let mut mlps = Vec::new();
1243 for layer in &mut self.layers {
1244 mlps.push(&mut layer.mlp);
1245 }
1246 mlps
1247 }
1248 fn create_anymoe_layers(
1249 &mut self,
1250 additional_vbs: Vec<ShardedVarBuilder>,
1251 config: AnyMoeConfig,
1252 (prefix, mlp): (String, String),
1253 mut layers: Vec<usize>,
1254 expert_type: AnyMoeExpertType,
1255 gate_vb: Option<ShardedVarBuilder>,
1256 ) -> Result<()> {
1257 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
1258 if layers.is_empty() {
1259 layers = (0..self.layers.len()).collect::<Vec<_>>();
1260 }
1261 for _ in 0..layers.len() {
1262 experts.push(Vec::new());
1263 }
1264 for vb in additional_vbs {
1265 let vb = vb.pp(&prefix);
1266 for (layer, row) in experts.iter_mut().enumerate() {
1267 if !layers.contains(&layer) {
1268 continue;
1269 }
1270
1271 let intermediate_size = self.layers[layer].mlp.get_params()[1];
1272 let hidden_size = self.layers[layer].mlp.get_params()[0];
1273 match expert_type {
1274 AnyMoeExpertType::FineTuned => {
1275 row.push(Box::new(Mlp::new(
1276 &Config {
1277 intermediate_size: self.layers[layer].mlp.get_params()[1],
1278 hidden_size: self.layers[layer].mlp.get_params()[0],
1279 ..Default::default()
1280 },
1281 vb.pp(layer).pp(&mlp),
1282 )?));
1283 }
1284 AnyMoeExpertType::LoraAdapter {
1285 rank,
1286 alpha,
1287 ref target_modules,
1288 } => {
1289 let vb_mlp = vb.pp(layer).pp(&mlp);
1290
1291 let gate_up_proj_delta =
1292 if target_modules.contains(&"gate_up_proj".to_string()) {
1293 Some(get_delta_from_lora_ab!(
1294 vb_mlp,
1295 rank,
1296 alpha,
1297 (hidden_size, 2 * intermediate_size),
1298 "gate_up_proj"
1299 ))
1300 } else {
1301 None
1302 };
1303 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
1304 Some(get_delta_from_lora_ab!(
1305 vb_mlp,
1306 rank,
1307 alpha,
1308 (hidden_size, intermediate_size),
1309 "down_proj"
1310 ))
1311 } else {
1312 None
1313 };
1314
1315 row.push(
1316 self.layers[layer]
1317 .mlp
1318 .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
1319 );
1320 }
1321 }
1322 }
1323 }
1324 for (layer, expert) in layers.into_iter().zip(experts) {
1325 let mut experts_all = vec![self.layers[layer].mlp.clone()];
1326 experts_all.extend(expert);
1327 let (dtype, device) = self.layers[layer].mlp.dtype_device();
1328 self.layers[layer].mlp = Box::new(MoeMlp::new(
1329 experts_all,
1330 config.clone(),
1331 dtype,
1332 &device,
1333 layer,
1334 gate_vb.as_ref(),
1335 )?);
1336 }
1337 Ok(())
1338 }
1339 fn amoe_supported(&self) -> bool {
1340 true
1341 }
1342}