1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3pub(crate) mod phi3_inputs_processor;
4
5use candle_core::{
8 shape::ShapeWithOneHole, DType, Device, IndexOp, Module, Result, Shape, Tensor, D,
9};
10use either::Either;
11use mistralrs_quant::{
12 BitWiseOp, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder,
13};
14use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
15
16use crate::{
17 amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
18 attention::SdpaParams,
19 device_map::DeviceMapper,
20 get_delta_from_lora_ab,
21 layers::{
22 self, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
23 PhiRotaryEmbedding, RmsNorm, Sdpa,
24 },
25 layers_masker::PastKvLenCache,
26 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
27 pipeline::{
28 extract_logits,
29 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
30 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, VisionModel,
31 },
32 serde_default_fn,
33 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
34 vision_models::clip::{ClipConfig, ClipVisionTransformer},
35 AnyMoeConfig, AnyMoeExpertType,
36};
37
38use super::clip;
39
/// Deserialized `embd_layer` section of the Phi-3-vision config: controls how
/// image features are projected and arranged before entering the language model.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct EmbedLayerConfig {
    // Order of global vs. sub-image features ("glb_sub" or "sub_glb"); defaults to "glb_sub".
    pub hd_transform_order: Option<String>,
    // Projection head type ("linear" or "mlp"); defaults to "linear".
    pub projection_cls: Option<String>,
    // Enable the high-definition (multi-crop) transform; defaults to false.
    pub use_hd_transform: Option<bool>,
    // Load learnable `glb_GN`/`sub_GN` separator tensors; must match `use_hd_transform`.
    pub with_learnable_separator: Option<bool>,
}
47
/// Deserialized `img_processor` section of the config describing the vision tower.
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct ImageProcessorConfig {
    // Feature dimension produced by the vision tower.
    pub image_dim_out: usize,
    // Vision backbone identifier; only "clip_vision_model" is supported.
    pub name: String,
    // Number of image tokens per image in the non-HD path.
    pub num_img_tokens: usize,
    // Hidden-state layer to take features from (negative = from the end); defaults to -2.
    pub layer_idx: Option<isize>,
    // Which tokens to keep: "patch" (drop CLS) or "cls_patch"; defaults to "patch".
    pub type_feature: Option<String>,
}
56
// Serde default for `tie_word_embeddings` when the field is absent from the config JSON.
serde_default_fn!(bool, word_emb_default, false);
58
/// Top-level Phi-3-vision model configuration (deserialized from `config.json`).
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    // Long-context RoPE scaling (e.g. su/longrope); `None` disables scaling.
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    // Pre-extension context length used by the RoPE scaling schedule.
    pub original_max_position_embeddings: usize,
    // Vision-specific sub-configs.
    pub embd_layer: EmbedLayerConfig,
    pub img_processor: ImageProcessorConfig,
    // Accept both `quantization_config` and legacy `quantization` keys.
    #[serde(alias = "quantization")]
    pub quantization_config: Option<QuantizedConfig>,
    // When true, `lm_head` reuses the embedding matrix instead of its own weights.
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}
83
84impl From<Config> for PhiRopeConfig {
85 fn from(val: Config) -> Self {
86 PhiRopeConfig {
87 rope_scaling: val.rope_scaling,
88 max_position_embeddings: val.max_position_embeddings,
89 original_max_position_embeddings: val.original_max_position_embeddings,
90 rope_theta: val.rope_theta,
91 head_dim: val.hidden_size / val.num_attention_heads,
92 partial_rotary_factor: None,
93 }
94 }
95}
96
impl Config {
    /// Per-attention-head dimension: `hidden_size / num_attention_heads`.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}
102
/// A `Module` that can also report where its parameters live, so inputs can be
/// moved/cast to match before the forward pass (see `ImageEmbedding::forward`).
trait ModuleWithMetadata: Module + Debug + Send + Sync {
    fn device(&self) -> Device;
    fn dtype(&self) -> DType;
}
107
/// Newtype adapter giving an `Arc<dyn QuantMethod>` the `Module`/`ModuleWithMetadata`
/// interface expected by `EmbeddingLayers`.
#[derive(Debug)]
struct QuantMethodWrapper(Arc<dyn QuantMethod>);
110
impl Module for QuantMethodWrapper {
    // Delegate straight to the quantized linear layer.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}
116
impl ModuleWithMetadata for QuantMethodWrapper {
    // NOTE(review): `unquant_weight_bias()` is unwrapped here; this assumes the
    // wrapped method can always expose an unquantized weight — TODO confirm for
    // all quantization backends used with the image projection.
    fn device(&self) -> Device {
        self.0.unquant_weight_bias().unwrap().0.device().clone()
    }
    fn dtype(&self) -> DType {
        self.0.unquant_weight_bias().unwrap().0.dtype()
    }
}
125
impl ModuleWithMetadata for candle_nn::Activation {
    // Activations are stateless; metadata is only ever queried on the first
    // layer of `EmbeddingLayers`, which is always a linear projection, so these
    // accessors are never reached.
    fn device(&self) -> Device {
        unreachable!()
    }
    fn dtype(&self) -> DType {
        unreachable!()
    }
}
134
/// A 6-dimensional target shape whose last dimension is a "hole" (`()`) to be
/// inferred from the element count — presumably needed because candle's built-in
/// `ShapeWithOneHole` impls do not cover 6-tuples (TODO confirm against candle).
#[derive(Debug)]
struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ()));
137
138fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
139 if prod_d == 0 {
140 candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
141 }
142 if el_count % prod_d != 0 {
143 candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
144 }
145 Ok(el_count / prod_d)
146}
147
impl ShapeWithOneHole for BigShapeWithOneHole {
    // Resolve the trailing hole: d6 = el_count / (d1*d2*d3*d4*d5).
    fn into_shape(self, el_count: usize) -> Result<Shape> {
        let (d1, d2, d3, d4, d5, ()) = self.0;
        let d = hole_size(el_count, d1 * d2 * d3 * d4 * d5, &self)?;
        Ok((d1, d2, d3, d4, d5, d).into())
    }
}
155
/// Multi-head self-attention with a fused QKV projection, Phi-style RoPE, and
/// optional paged-attention backend.
struct Attention {
    // Fused Q/K/V projection: output width = (num_heads + 2*num_kv_heads) * head_dim.
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    // `Some` when using the paged-attention implementation, `None` for eager SDPA.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
168
impl Attention {
    /// Load the attention weights for one layer from `vb` and precompute the
    /// SDPA parameters (GQA group count, 1/sqrt(head_dim) scale, sliding window).
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        // Fused projection width: Q heads plus K and V KV-heads.
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// One attention forward pass for a `(b, q_len, hidden)` input, appending
    /// to the per-layer KV cache (eager path) or using paged attention.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // If the projection is quantized, run the matmul in its activation
        // dtype and convert back afterwards.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        // Split the fused output along the last dim: [Q | K | V].
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        // Shape to (b, heads, seq, head_dim). For single-token decode
        // (q_len == 1) a direct reshape avoids the transpose.
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // Apply rotary position embeddings to Q and K.
        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // No real KV cache supplied (e.g. prompt-chunk warmup): run
                    // paged attention with a dummy metadata and no cache writes.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Eager path: append to the rolling KV cache, then SDPA.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // With a mask (prefill) the output is (b, heads, seq, hd) and needs a
        // transpose; the masked-less decode output is already (b, seq, ...).
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
327
/// Gated MLP (SwiGLU-style) with a fused gate+up projection.
#[derive(Clone)]
struct Mlp {
    // Fused projection: first `i_size` outputs are the gate, next `i_size` the up path.
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    // [hidden_size, intermediate_size] — reported via `MlpLayer::get_params` for AnyMoE.
    params: Vec<usize>,
}
336
impl Mlp {
    /// Load the MLP weights for one layer from `vb`.
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        // Fused gate+up projection doubles the output width.
        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}
366
// Marker impl: allows this MLP to serve as a trainable AnyMoE expert.
impl AnyMoeTrainableLayer for Mlp {}
368
impl MlpLayer for Mlp {
    // down_proj( up * act(gate) ), with the same quantized-activation dtype
    // round-trip as in `Attention::forward`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
        // First half of the fused output is the gate, second half the up path.
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    // Layers eligible for in-situ quantization.
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    /// Build a copy with LoRA deltas merged into the weights.
    /// `deltas[0]` applies to `gate_up_proj`, `deltas[1]` to `down_proj`.
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}
424
/// One pre-norm transformer decoder layer: RMSNorm → attention → residual,
/// then RMSNorm → MLP → residual.
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}
431
432impl DecoderLayer {
433 fn new(
434 rotary_emb: Arc<PhiRotaryEmbedding>,
435 cfg: &Config,
436 vb: ShardedVarBuilder,
437 mapper: &dyn DeviceMapper,
438 layer_idx: usize,
439 loading_isq: bool,
440 paged_attn: Option<PagedAttention>,
441 ) -> Result<Self> {
442 let self_attn = Attention::new(
443 rotary_emb,
444 cfg,
445 mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
446 paged_attn,
447 )?;
448 let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
449 let input_layernorm = RmsNorm::new(
450 cfg.hidden_size,
451 cfg.rms_norm_eps,
452 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
453 )?;
454 let post_attention_layernorm = RmsNorm::new(
455 cfg.hidden_size,
456 cfg.rms_norm_eps,
457 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
458 )?;
459 Ok(Self {
460 self_attn,
461 mlp: Box::new(mlp),
462 input_layernorm,
463 post_attention_layernorm,
464 })
465 }
466
467 #[allow(clippy::too_many_arguments)]
468 fn forward(
469 &self,
470 xs: &Tensor,
471 attention_mask: Option<&Tensor>,
472 seqlen_offsets: &[usize],
473 position_ids: &[usize],
474 kv_cache: &mut KvCache,
475 metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
476 flash_params: &FlashParams,
477 ) -> Result<Tensor> {
478 let residual = xs;
479 let xs = self.input_layernorm.forward(xs)?;
480 let xs = self
481 .self_attn
482 .forward(
483 &xs,
484 attention_mask,
485 seqlen_offsets,
486 position_ids,
487 kv_cache,
488 metadata,
489 flash_params,
490 )
491 .unwrap();
492 let xs = (xs + residual)?;
493 let residual = &xs;
494 let xs = self
495 .mlp
496 .forward(&xs.apply(&self.post_attention_layernorm)?)?;
497 residual + xs
498 }
499}
500
// Image placeholder tokens are encoded as negative input ids; any id in
// (-MAX_INPUT_ID, 0) is treated as an image token in `ImageEmbedding::forward`.
const MAX_INPUT_ID: f64 = 1e9;
506
/// Sequential container for the image projection head (linear, or linear→GELU→linear).
#[derive(Debug)]
struct EmbeddingLayers(Vec<Box<dyn ModuleWithMetadata>>);
509
510impl Module for EmbeddingLayers {
511 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
512 let mut xs = xs.clone();
513 for layer in &self.0 {
514 xs = layer.forward(&xs)?;
515 }
516 Ok(xs)
517 }
518}
519
/// Embeds mixed text/image inputs: text tokens go through `wte`, image
/// placeholder positions are replaced by projected CLIP features.
#[derive(Debug)]
pub struct ImageEmbedding {
    // Language-model token embedding, shared with the text path.
    wte: candle_nn::Embedding,
    image_dim_out: usize,
    num_img_tokens: usize,
    // Learnable global/sub-image separators (present iff the HD transform is on).
    glb_gn: Option<Tensor>,
    sub_gn: Option<Tensor>,
    // Projection from vision features to the LM hidden size.
    layers: EmbeddingLayers,
    type_feature: String,
    // Which CLIP hidden state to use (negative = from the end).
    layer_idx: isize,
    image_processor: ClipVisionTransformer,
    hd_transform_order: String,
    use_hd_transform: bool,
    vocab_size: usize,
    // Projection weights retained for `residual_tensors` serialization.
    tensors: Vec<(String, Tensor)>,
}
536
// Fixed CLIP vision-tower hyperparameters used by Phi-3-vision
// (matches a ViT-L/14 tower at 336px input — 24 layers, hidden 1024).
pub(crate) const PHI3V_CLIP_CONFIG: ClipConfig = ClipConfig {
    hidden_act: clip::Activation::QuickGelu,
    hidden_size: 1024,
    image_size: 336,
    intermediate_size: 4096,
    num_attention_heads: 16,
    num_channels: 3,
    num_hidden_layers: 24,
    patch_size: 14,
};
547
548impl ImageEmbedding {
549 fn new(
550 config: &Config,
551 wte: candle_nn::Embedding,
552 embed_config: &EmbedLayerConfig,
553 vb: ShardedVarBuilder,
554 ) -> Result<Self> {
555 let hidden_size = config.hidden_size;
556 if config.img_processor.name != "clip_vision_model" {
557 candle_core::bail!(
558 "img_processor=`{}` nor supported.",
559 config.img_processor.name
560 );
561 }
562 let image_dim_out = config.img_processor.image_dim_out;
563 let num_img_tokens = config.img_processor.num_img_tokens;
564
565 let image_processor =
567 ClipVisionTransformer::new(vb.pp("img_processor.vision_model"), &PHI3V_CLIP_CONFIG)?;
568
569 let use_hd_transform = embed_config.use_hd_transform.unwrap_or(false);
571 let with_learnable_separator = embed_config.with_learnable_separator.unwrap_or(false);
572 let hd_transform_order = embed_config
573 .hd_transform_order
574 .clone()
575 .unwrap_or("glb_sub".to_string());
576 assert_eq!(use_hd_transform, with_learnable_separator);
577 let (glb_gn, sub_gn) = if with_learnable_separator {
578 let glb_gn = vb.get((1, 1, image_dim_out * 4), "glb_GN")?;
579 let sub_gn = vb.get((1, 1, 1, image_dim_out * 4), "sub_GN")?;
580 (Some(glb_gn), Some(sub_gn))
581 } else {
582 (None, None)
583 };
584
585 let projection_cls = embed_config
587 .projection_cls
588 .clone()
589 .unwrap_or("linear".to_string());
590
591 let mut tensors = Vec::new();
592 let layers: Vec<Box<dyn ModuleWithMetadata>> =
593 match (projection_cls.as_str(), use_hd_transform) {
594 ("linear", _) => {
595 let a = mistralrs_quant::linear_b(
596 image_dim_out,
597 hidden_size,
598 true,
599 &None,
600 vb.pp("img_projection"),
601 )?;
602 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
603 tensors.push(("img_projection.weight".to_string(), a_w));
604 if let Some(b) = a_b {
605 tensors.push(("img_projection.bias".to_string(), b));
606 }
607 vec![Box::new(QuantMethodWrapper(a))]
608 }
609 ("mlp", true) => {
610 let dim_proj = hidden_size;
611 let a = mistralrs_quant::linear_b(
612 image_dim_out * 4,
613 dim_proj,
614 true,
615 &None,
616 vb.pp("img_projection.0"),
617 )?;
618 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
619 tensors.push(("img_projection.0.weight".to_string(), a_w));
620 if let Some(b) = a_b {
621 tensors.push(("img_projection.0.bias".to_string(), b));
622 }
623 let b = mistralrs_quant::linear_b(
624 dim_proj,
625 dim_proj,
626 true,
627 &None,
628 vb.pp("img_projection.2"),
629 )?;
630 let (b_w, b_b) = b.unquant_weight_bias().unwrap();
631 tensors.push(("img_projection.2.weight".to_string(), b_w));
632 if let Some(b) = b_b {
633 tensors.push(("img_projection.2.bias".to_string(), b));
634 }
635 vec![
636 Box::new(QuantMethodWrapper(a)),
637 Box::new(candle_nn::Activation::Gelu),
638 Box::new(QuantMethodWrapper(b)),
639 ]
640 }
641 ("mlp", false) => {
642 let dim_proj = hidden_size;
643 let a = mistralrs_quant::linear_b(
644 image_dim_out,
645 dim_proj,
646 true,
647 &None,
648 vb.pp("img_projection.0"),
649 )?;
650 let (a_w, a_b) = a.unquant_weight_bias().unwrap();
651 tensors.push(("img_projection.0.weight".to_string(), a_w));
652 if let Some(b) = a_b {
653 tensors.push(("img_projection.0.bias".to_string(), b));
654 }
655 let b = mistralrs_quant::linear_b(
656 dim_proj,
657 dim_proj,
658 true,
659 &None,
660 vb.pp("img_projection.2"),
661 )?;
662 let (b_w, b_b) = b.unquant_weight_bias().unwrap();
663 tensors.push(("img_projection.2.weight".to_string(), b_w));
664 if let Some(b) = b_b {
665 tensors.push(("img_projection.2.bias".to_string(), b));
666 }
667 vec![
668 Box::new(QuantMethodWrapper(a)),
669 Box::new(candle_nn::Activation::Gelu),
670 Box::new(QuantMethodWrapper(b)),
671 ]
672 }
673 _ => {
674 candle_core::bail!("projection_cls=`{projection_cls}` not implemented.");
675 }
676 };
677
678 let layer_idx = config.img_processor.layer_idx.unwrap_or(-2);
679 let type_feature = config
680 .img_processor
681 .type_feature
682 .clone()
683 .unwrap_or("patch".to_string());
684
685 Ok(Self {
686 wte,
687 image_dim_out,
688 num_img_tokens,
689 glb_gn,
690 sub_gn,
691 layer_idx,
692 type_feature,
693 image_processor,
694 layers: EmbeddingLayers(layers),
695 hd_transform_order,
696 use_hd_transform,
697 vocab_size: config.vocab_size,
698 tensors,
699 })
700 }
701
702 fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
703 let hidden_states = self
704 .image_processor
705 .forward_get_hidden_states(&pixel_values.to_dtype(self.wte.embeddings().dtype())?)?;
706 let img_feature =
707 hidden_states[(hidden_states.len() as isize + self.layer_idx) as usize].clone();
708 if self.type_feature == "patch" {
709 img_feature.i((.., 1..))
710 } else if self.type_feature == "cls_patch" {
711 Ok(img_feature)
712 } else {
713 candle_core::bail!("Unsupported image feature type {}", self.type_feature)
714 }
715 }
716
    /// Embed a batch of mixed text/image token ids.
    ///
    /// Image placeholders are negative ids in (-MAX_INPUT_ID, 0). Their
    /// positions are located, the corresponding CLIP features (optionally
    /// HD-transformed) are projected, and the projected features are written
    /// over those positions in the token-embedding output.
    #[allow(non_snake_case)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: &Tensor,
        image_sizes: Option<Vec<(usize, usize)>>,
    ) -> Result<Tensor> {
        let input_ids = input_ids.reshape(((), input_ids.dim(D::Minus1)?))?;

        // (row, col) coordinates of every image placeholder token.
        let input_ids_lt = input_ids.lt(0.0f64)?;
        let input_ids_gt = input_ids.gt(-MAX_INPUT_ID)?;
        let positions = input_ids_lt.bitwise_and(&input_ids_gt)?.nonzero()?;
        // First projection layer defines the target device/dtype; it is always
        // a linear layer (never an activation), so these accessors are valid.
        let target_dev = self.layers.0[0].device();
        let target_dtype = self.layers.0[0].dtype();

        let mut select = false;
        let mut hd_transform = None;
        let mut image_set_tensor = None;
        if positions.dim(0)? > 0 {
            select = true;
            if self.use_hd_transform && image_sizes.is_some() {
                // HD path: pixel_values is (bs, num_crops, c, h, w); each crop
                // runs through CLIP, then 2x2 patch merging (feature dim x4).
                assert_eq!(pixel_values.dims().len(), 5);
                let bs = pixel_values.dim(0)?;
                let img_features = self.get_image_features(&pixel_values.flatten(0, 1)?)?;
                // 336px / patch 14 = 24x24 patches per crop.
                let base_feat_dim = (img_features.dims()[1] as f32).sqrt() as usize;
                assert_eq!(base_feat_dim, 24);

                let img_features =
                    img_features.reshape((bs, (), base_feat_dim.pow(2), self.image_dim_out))?;
                let C = self.image_dim_out;
                let H = base_feat_dim;

                let mut output_imgs = Vec::new();
                let mut output_len = Vec::new();
                for bs_ in 0..bs {
                    // Crop grid for this image, in 336px tiles.
                    let (h, w) = image_sizes.as_ref().unwrap()[bs_];
                    let h = h / 336;
                    let w = w / 336;
                    let B_ = h * w;

                    // Crop 0 is the downscaled global view.
                    let global_img_feature = img_features.i((bs_, ..1))?;

                    // 2x2 merge: (1, H, H, C) -> (1, H/2, H/2, 4C).
                    let glb_img = global_img_feature
                        .reshape((1, H, H, C))?
                        .reshape((1, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, H / 2, H / 2, 4 * C))?
                        .contiguous()?;
                    // Append a separator column per row, then flatten rows.
                    let temp_glbl_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, H / 2, 1, 1))?;

                    let glb_img =
                        Tensor::cat(&[glb_img, temp_glbl_gn], 2)?.reshape((1, (), 4 * C))?;

                    // Remaining crops are the high-res sub-tiles.
                    let sub_img = img_features.i((bs_, 1..))?;

                    // Keep only the tiles this image actually uses.
                    let sub_img = sub_img.i(..B_)?;

                    // Same 2x2 merge per tile, then stitch the h x w tile grid
                    // into one (h*12) x (w*12) feature map.
                    let sub_img = sub_img
                        .reshape((B_, H, H, C))?
                        .reshape((B_, H / 2, 2, H / 2, 2, C))?
                        .contiguous()?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((B_, (), 4 * C))?
                        .contiguous()?;
                    let sub_img = sub_img
                        .reshape(BigShapeWithOneHole((1usize, h, w, 12usize, 12usize, ())))?
                        .permute((0, 1, 3, 2, 4, 5))?
                        .reshape((1, h * 12, w * 12, 4 * C))?;
                    let temp_sub_gn = self
                        .sub_gn
                        .as_ref()
                        .expect("Need `sub_gn` if `use_hd_transform`")
                        .repeat((1, h * 12, 1, 1))?;

                    let sub_img =
                        Tensor::cat(&[sub_img, temp_sub_gn], 2)?.reshape((1, (), 4 * C))?;

                    // Concatenate global and sub features around the `glb_gn`
                    // separator, in the configured order.
                    match self.hd_transform_order.as_str() {
                        "glb_sub" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    glb_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    sub_img,
                                ],
                                1,
                            )?);
                        }
                        "sub_glb" => {
                            output_imgs.push(Tensor::cat(
                                &[
                                    sub_img,
                                    self.glb_gn
                                        .as_ref()
                                        .expect("Need `glb_gn` if `use_hd_transform`")
                                        .clone(),
                                    glb_img,
                                ],
                                1,
                            )?);
                        }
                        other => {
                            candle_core::bail!("Invalid hd_transform_order=`{other}`");
                        }
                    }

                    // Expected token count: (tiles + global) * 144 merged
                    // patches, +1 glb_gn, + one separator per row.
                    let temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12;
                    assert_eq!(temp_len, output_imgs.last().unwrap().dims()[1]);
                    output_len.push(temp_len);
                }

                hd_transform = Some(output_len);
                let mut image_set_tensor_inner = Vec::new();
                for img in output_imgs {
                    let layerout = self
                        .layers
                        .forward(&img.to_device(&target_dev)?.to_dtype(target_dtype)?)?;
                    image_set_tensor_inner.push(layerout);
                }
                image_set_tensor = Some(Either::Left(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 4 {
                // Non-HD path with raw images: CLIP features then projection.
                let tt = self
                    .get_image_features(pixel_values)?
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else if pixel_values.dims().len() == 3 {
                // Precomputed image features: project directly.
                let tt = pixel_values
                    .to_device(&target_dev)?
                    .to_dtype(target_dtype)?
                    .reshape(((), self.image_dim_out))?;
                let image_set_tensor_inner = self.layers.forward(&tt)?;
                image_set_tensor = Some(Either::Right(image_set_tensor_inner));
            } else {
                unreachable!()
            }
        }

        // Clamp placeholders (negative ids) to valid range before embedding;
        // those positions are overwritten below anyway.
        let input_ids = input_ids.clamp(0.0, self.vocab_size as f64)?;
        let mut hidden_states = self.wte.forward(&input_ids)?;
        if select {
            match (hd_transform, image_set_tensor) {
                // HD path: one variable-length feature block per image.
                (Some(output_lens), Some(Either::Left(image_set_tensors))) => {
                    let mut idx = 0;
                    for (i, cnt) in output_lens.into_iter().enumerate() {
                        let img_set_tensor = image_set_tensors[i]
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        // Placeholder run for this image starts at positions[idx].
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                // Non-HD path: fixed `num_img_tokens` per image.
                (None, Some(Either::Right(image_set_tensor))) => {
                    let mut idx = 0;
                    for i in 0..pixel_values.dim(0)? {
                        let cnt = self.num_img_tokens;
                        let img_set_tensor = image_set_tensor
                            .i(i * cnt..(i + 1) * cnt)?
                            .to_device(&target_dev)?
                            .to_dtype(target_dtype)?;
                        let p_0 = positions.i((idx, 0))?.to_scalar::<u32>()? as usize;
                        let p_1 = positions.i((idx, 1))?.to_scalar::<u32>()? as usize;
                        hidden_states = hidden_states.slice_assign(
                            &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
                            &img_set_tensor,
                        )?;
                        idx += cnt;
                    }
                }
                _ => unreachable!(),
            }
        }

        Ok(hidden_states)
    }
924
    /// Collect the non-ISQ tensors of the image embedding (separators,
    /// projection weights, CLIP tower) for serialization.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        if let Some(glb_gn) = self.glb_gn.clone() {
            uvb.add_tensor("glb_GN", glb_gn);
        }
        if let Some(sub_gn) = self.sub_gn.clone() {
            uvb.add_tensor("sub_GN", sub_gn);
        }
        // Projection weights captured at load time in `new`.
        uvb.extend(self.tensors.clone());
        uvb.pp("img_processor.vision_model")
            .extend(self.image_processor.residual_tensors());

        uvb.to_safetensors()
    }
940}
941
/// Full Phi-3-vision model: image embedding front-end plus the text decoder stack.
pub struct Model {
    vision_embed_tokens: ImageEmbedding,
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    // Device the final logits are computed on.
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    // Maps per-layer tensors across devices for multi-device inference.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}
957
impl Model {
    /// Load the full model from `vb`, distributing layers per the device mapper
    /// and wiring either eager SDPA or paged attention.
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        // The image embedder shares the token embedding with the text path.
        let vision_embed_tokens = ImageEmbedding::new(
            cfg,
            embed_tokens.clone(),
            &cfg.embd_layer,
            mapper.set_nm_device(vb_m.pp("vision_embed_tokens"), false),
        )?;
        let vb_l = vb_m.pp("layers");
        // One rotary embedding per distinct device so layers never need a
        // cross-device RoPE lookup.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let layers = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            DecoderLayer::new(
                rotary_emb,
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        // Tied embeddings reuse the embedding matrix as the LM head.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };

        Ok(Self {
            vision_embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            embed_tokens,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Head counts are per-rank under tensor parallelism.
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    /// Full forward pass: embed (text-only or multimodal), run the decoder
    /// stack, final norm, and LM head; `context_lens` selects which logits
    /// to extract.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        image_sizes: Option<Vec<(usize, usize)>>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // Multimodal embedding only when pixel values are present.
        let mut xs = if let Some(ref pixel_values) = pixel_values {
            self.vision_embed_tokens
                .forward(input_ids, pixel_values, image_sizes)?
        } else {
            self.embed_tokens.forward(input_ids)?
        };
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            // With paged attention, past length comes from seqlen offsets
            // rather than the eager KV cache.
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // Only the first prompt chunk needs a mask under paged attention.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to layer i's device before the layer runs.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
1134
impl IsqModel for Model {
    /// Enumerate quantizable layers (lm_head plus each layer's attention and
    /// MLP projections) with their layer index for device mapping.
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        // `None` index: lm_head is not tied to a repeating layer.
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    /// Tensors NOT covered by ISQ: embeddings, norms, and the vision stack.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);
        uvb_m
            .pp("vision_embed_tokens")
            .extend(self.vision_embed_tokens.residual_tensors());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}
1180
/// Vision-specific arguments threaded through the generic `VisionModel`
/// interface via `Box<dyn Any>` (see `VisionModel::forward` below, which
/// downcasts back to this type).
#[derive(Default)]
pub(crate) struct Phi3VisionSpecificArgs {
    /// Per-image sizes accompanying `pixel_values`; `None` when the prompt
    /// has no images. NOTE(review): presumably (height, width) pairs — confirm
    /// against the inputs processor that builds this struct.
    pub image_sizes: Option<Vec<(usize, usize)>>,
}
1185
1186impl VisionModel for Model {
1187 fn forward(
1188 &self,
1189 input_ids: &Tensor,
1190 pixel_values: Option<Tensor>,
1191 seqlen_offsets: &[usize],
1192 context_lens: Vec<(usize, usize)>,
1193 position_ids: Vec<usize>,
1194 model_specific_args: Box<dyn Any>,
1195 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1196 flash_params: &FlashParams,
1197 ) -> Result<Tensor> {
1198 let Phi3VisionSpecificArgs { image_sizes } = *model_specific_args
1199 .downcast()
1200 .expect("Cannot downcast into `Phi3VisionSpecificArgs`");
1201 self.forward(
1202 input_ids,
1203 pixel_values,
1204 seqlen_offsets,
1205 &position_ids,
1206 context_lens,
1207 image_sizes,
1208 metadata,
1209 flash_params,
1210 )
1211 }
1212 fn cache(&self) -> &EitherCache {
1213 &self.cache
1214 }
1215 fn cache_mut(&mut self) -> &mut EitherCache {
1216 &mut self.cache
1217 }
1218 fn device(&self) -> &Device {
1219 &self.device
1220 }
1221 fn max_seq_len(&self) -> usize {
1222 self.max_seq_len
1223 }
1224 fn config(&self) -> &ModelConfigMetadata {
1225 &self.cfg
1226 }
1227 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
1228 Box::new(Phi3VisionSpecificArgs::default())
1229 }
1230}
1231
1232impl AnyMoeBaseModelMixin for Model {
1233 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
1234 let mut mlps = Vec::new();
1235 for layer in &self.layers {
1236 mlps.push(&*layer.mlp);
1237 }
1238 mlps
1239 }
1240 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
1241 let mut mlps = Vec::new();
1242 for layer in &mut self.layers {
1243 mlps.push(&mut layer.mlp);
1244 }
1245 mlps
1246 }
1247 fn create_anymoe_layers(
1248 &mut self,
1249 additional_vbs: Vec<ShardedVarBuilder>,
1250 config: AnyMoeConfig,
1251 (prefix, mlp): (String, String),
1252 mut layers: Vec<usize>,
1253 expert_type: AnyMoeExpertType,
1254 gate_vb: Option<ShardedVarBuilder>,
1255 ) -> Result<()> {
1256 let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
1257 if layers.is_empty() {
1258 layers = (0..self.layers.len()).collect::<Vec<_>>();
1259 }
1260 for _ in 0..layers.len() {
1261 experts.push(Vec::new());
1262 }
1263 for vb in additional_vbs {
1264 let vb = vb.pp(&prefix);
1265 for (layer, row) in experts.iter_mut().enumerate() {
1266 if !layers.contains(&layer) {
1267 continue;
1268 }
1269
1270 let intermediate_size = self.layers[layer].mlp.get_params()[1];
1271 let hidden_size = self.layers[layer].mlp.get_params()[0];
1272 match expert_type {
1273 AnyMoeExpertType::FineTuned => {
1274 row.push(Box::new(Mlp::new(
1275 &Config {
1276 intermediate_size: self.layers[layer].mlp.get_params()[1],
1277 hidden_size: self.layers[layer].mlp.get_params()[0],
1278 ..Default::default()
1279 },
1280 vb.pp(layer).pp(&mlp),
1281 )?));
1282 }
1283 AnyMoeExpertType::LoraAdapter {
1284 rank,
1285 alpha,
1286 ref target_modules,
1287 } => {
1288 let vb_mlp = vb.pp(layer).pp(&mlp);
1289
1290 let gate_up_proj_delta =
1291 if target_modules.contains(&"gate_up_proj".to_string()) {
1292 Some(get_delta_from_lora_ab!(
1293 vb_mlp,
1294 rank,
1295 alpha,
1296 (hidden_size, 2 * intermediate_size),
1297 "gate_up_proj"
1298 ))
1299 } else {
1300 None
1301 };
1302 let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
1303 Some(get_delta_from_lora_ab!(
1304 vb_mlp,
1305 rank,
1306 alpha,
1307 (hidden_size, intermediate_size),
1308 "down_proj"
1309 ))
1310 } else {
1311 None
1312 };
1313
1314 row.push(
1315 self.layers[layer]
1316 .mlp
1317 .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
1318 );
1319 }
1320 }
1321 }
1322 }
1323 for (layer, expert) in layers.into_iter().zip(experts) {
1324 let mut experts_all = vec![self.layers[layer].mlp.clone()];
1325 experts_all.extend(expert);
1326 let (dtype, device) = self.layers[layer].mlp.dtype_device();
1327 self.layers[layer].mlp = Box::new(MoeMlp::new(
1328 experts_all,
1329 config.clone(),
1330 dtype,
1331 &device,
1332 layer,
1333 gate_vb.as_ref(),
1334 )?);
1335 }
1336 Ok(())
1337 }
1338 fn amoe_supported(&self) -> bool {
1339 true
1340 }
1341}