1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6 quantized::{QMatMul, QTensor},
7 Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10 Conv2d, Conv2dConfig, Embedding, GroupNorm, LayerNorm, LayerNormConfig, Linear, Module,
11};
12use float8::F8E4M3;
13use half::{bf16, f16};
14use mistralrs_quant::{
15 AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
16 ShardedVarBuilder,
17};
18use serde::{Deserialize, Serialize};
19
20pub use crate::attention::Sdpa;
21pub use crate::layers_masker::CausalMasker;
22pub use crate::layers_utils::repeat_kv;
23use crate::{
24 amoe::{AnyMoeTrainableLayer, MlpLayer},
25 gguf::Content,
26 models::llama,
27 ops::SplitOp,
28 vision_models::{
29 gemma3::config::Gemma3TextConfig,
30 llama4,
31 mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
32 phi4::Phi4MMConfig,
33 },
34};
35
36pub use mistralrs_quant::MatMul;
37
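/// Load an embedding table of shape `(in_size, out_size)`.
///
/// If the quantization config is AFQ, the weight is loaded through `AfqLayer`
/// and dequantized; otherwise the plain `weight` tensor is read directly.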
38pub fn embedding(
39 in_size: usize,
40 out_size: usize,
41 vb: ShardedVarBuilder,
42 config: &Option<QuantizedConfig>,
43) -> Result<Embedding> {
44 let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
46 let afq_layer =
47 AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
48 afq_layer.dequantize_w()?
49 } else {
50 vb.get_with_hints((in_size, out_size), "weight", Default::default())?
51 };
52 Ok(Embedding::new(embeddings, out_size))
53}
54
55pub fn layer_norm<C: Into<LayerNormConfig>>(
56 size: usize,
57 config: C,
58 vb: ShardedVarBuilder,
59) -> Result<LayerNorm> {
60 let config = config.into();
61 let weight = vb.get(size, "weight")?;
62 if config.affine {
63 let bias = vb.get(size, "bias")?;
64 Ok(LayerNorm::new(weight, bias, config.eps))
65 } else {
66 Ok(LayerNorm::new_no_bias(weight, config.eps))
67 }
68}
69
70pub fn group_norm(
71 num_groups: usize,
72 num_channels: usize,
73 eps: f64,
74 vb: ShardedVarBuilder,
75) -> Result<GroupNorm> {
76 let weight = vb.get(num_channels, "weight")?;
77 let bias = vb.get(num_channels, "bias")?;
78 GroupNorm::new(weight, bias, num_channels, num_groups, eps)
79}
80
81pub fn conv2d(
82 in_channels: usize,
83 out_channels: usize,
84 kernel_size: usize,
85 cfg: Conv2dConfig,
86 vb: ShardedVarBuilder,
87) -> Result<Conv2d> {
88 let ws = vb.get(
89 (
90 out_channels,
91 in_channels / cfg.groups,
92 kernel_size,
93 kernel_size,
94 ),
95 "weight",
96 )?;
97 let bs = vb.get(out_channels, "bias")?;
98 Ok(Conv2d::new(ws, Some(bs), cfg))
99}
100
101pub fn conv2d_no_bias(
102 in_channels: usize,
103 out_channels: usize,
104 kernel_size: usize,
105 cfg: Conv2dConfig,
106 vb: ShardedVarBuilder,
107) -> Result<Conv2d> {
108 let ws = vb.get(
109 (
110 out_channels,
111 in_channels / cfg.groups,
112 kernel_size,
113 kernel_size,
114 ),
115 "weight",
116 )?;
117 Ok(Conv2d::new(ws, None, cfg))
118}
119
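/// Load a `Linear` layer with bias from a `weight` tensor of shape `(out_dim, in_dim)`
/// and a `bias` tensor of shape `(out_dim,)`.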
120pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
121 let ws = vb.get((out_dim, in_dim), "weight")?;
122 let bs = vb.get(out_dim, "bias")?;
123 Ok(Linear::new(ws, Some(bs)))
124}
125
126pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
127 let ws = vb.get((out_dim, in_dim), "weight")?;
128 Ok(Linear::new(ws, None))
129}
130
131pub fn linear_b(
132 in_dim: usize,
133 out_dim: usize,
134 bias: bool,
135 vb: ShardedVarBuilder,
136) -> Result<Linear> {
137 if bias {
138 linear(in_dim, out_dim, vb)
139 } else {
140 linear_no_bias(in_dim, out_dim, vb)
141 }
142}
143
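/// RMSNorm with a learned scale, delegating to `candle_nn::ops::rms_norm`.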
144#[derive(Debug, Clone)]
145pub struct RmsNorm {
146 eps: f64,
147 weight: Tensor,
148}
149
150impl RmsNorm {
151 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
152 let w = vb.get(size, "weight")?;
153 Ok(Self { eps, weight: w })
154 }
155
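    /// Gemma checkpoints store the RMSNorm scale as `weight - 1`, so add 1 at load time.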
156 pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
158 let w = vb.get(size, "weight")?;
159 let w = (w + 1.0)?;
160 Ok(Self { eps, weight: w })
161 }
162
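    /// Inverse of [`Self::new_gemma`]: subtract 1 to recover the on-disk Gemma weight.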
163 pub fn undo_gemma(&self) -> Result<Self> {
165 Ok(Self {
166 eps: self.eps,
167 weight: (&self.weight - 1.0)?,
168 })
169 }
170
171 pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
172 Ok(Self { eps, weight: w })
173 }
174
175 pub fn weight(&self) -> &Tensor {
176 &self.weight
177 }
178}
179
180impl Module for RmsNorm {
181 fn forward(&self, x: &Tensor) -> Result<Tensor> {
182 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
183 }
184}
185
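/// RMSNorm that always computes in `f32` for numerical stability, casting back
/// to the input dtype before applying the scale.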
186#[derive(Debug, Clone)]
187pub struct F32RmsNorm {
188 w: Tensor,
189 eps: f64,
190}
191
192impl F32RmsNorm {
193 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
194 Ok(Self {
195 w: vb.get((size,), "weight")?,
196 eps,
197 })
198 }
199
200 pub fn weight(&self) -> &Tensor {
201 &self.w
202 }
203}
204
205impl Module for F32RmsNorm {
206 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
207 let initial_type = xs.dtype();
208 let mut xs = xs.to_dtype(DType::F32)?;
209 let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
210 xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
211 xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
212 }
213}
214
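/// RMSNorm whose scale is obtained by dequantizing a GGUF `QTensor`.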
215#[derive(Debug, Clone)]
216pub struct QRmsNorm {
217 eps: f64,
218 weight: Tensor,
219}
220
221impl QRmsNorm {
222 pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
223 let scale = scale.dequantize(&scale.device())?;
224 Ok(Self {
225 eps: eps as f64,
226 weight: scale,
227 })
228 }
229
230 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
231 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
232 }
233}
234
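/// Phi-3 style rotary embedding with separate short and long sin/cos tables
/// (`su`/`longrope` or `yarn` scaling). The long tables are used once the
/// highest position id reaches `original_max_position_embeddings`.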
235#[derive(Debug, Clone)]
237pub struct PhiRotaryEmbedding {
238 short_sin: Tensor,
239 short_cos: Tensor,
240 long_cos: Option<Tensor>,
241 long_sin: Option<Tensor>,
242 original_max_position_embeddings: usize,
243}
244
245#[derive(Debug, Clone, Deserialize, Serialize)]
246#[serde(rename_all = "lowercase")]
247pub enum ScaledRopeType {
248 #[serde(alias = "su")]
249 #[serde(alias = "longrope")]
250 Su,
251 #[serde(alias = "yarn")]
252 Yarn,
253 #[serde(alias = "dynamic")]
254 Dynamic,
255 #[serde(alias = "linear")]
256 Linear,
257}
258
259impl FromStr for ScaledRopeType {
260 type Err = candle_core::Error;
261 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
262 match s {
263 "su" | "longrope" => Ok(Self::Su),
264 "yarn" => Ok(Self::Yarn),
265 "linear" => Ok(Self::Linear),
266 "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Unknown scaled RoPE type: expected one of `su`/`longrope`, `yarn`, `linear`, or `dynamic`.".to_string(),
            )),
270 }
271 }
272}
273
274#[derive(Debug, Clone, Deserialize, Serialize)]
275#[serde(untagged)]
276pub enum PhiRopeScalingConfig {
277 Classic {
278 short_factor: Vec<f64>,
279 long_factor: Vec<f64>,
280 #[serde(rename = "type")]
281 scaling_type: ScaledRopeType,
282 },
283 Scaled {
284 short_factor: Vec<f64>,
285 long_factor: Vec<f64>,
286 #[serde(rename = "type")]
287 scaling_type: ScaledRopeType,
288 long_mscale: f64,
289 short_mscale: f64,
290 },
291}
292
293pub struct PhiRopeConfig {
294 pub rope_scaling: Option<PhiRopeScalingConfig>,
295 pub max_position_embeddings: usize,
296 pub original_max_position_embeddings: usize,
297 pub rope_theta: f64,
298 pub head_dim: usize,
299 pub partial_rotary_factor: Option<f64>,
300}
301
302impl PhiRotaryEmbedding {
303 fn new_classic_scaled(
304 short_factor: &[f64],
305 long_factor: &[f64],
306 scaling_type: &ScaledRopeType,
307 cfg: &PhiRopeConfig,
308 dtype: DType,
309 dev: &Device,
310 ) -> Result<Self> {
311 let max_seq_len = cfg.max_position_embeddings;
312 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
313
314 let scale =
316 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
317 let scaling_factor = if scale <= 1.0 {
318 1.0
319 } else {
320 match scaling_type {
321 ScaledRopeType::Su => {
322 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
323 }
324 ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
325 _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
326 }
327 };
328
329 let inv_freq_long = (0..dim)
331 .step_by(2)
332 .enumerate()
333 .map(|(k, i)| {
334 (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
335 })
336 .collect::<Vec<_>>();
337 let inv_freq_short = (0..dim)
338 .step_by(2)
339 .enumerate()
340 .map(|(k, i)| {
341 (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
342 })
343 .collect::<Vec<_>>();
344 let inv_freq_len = inv_freq_long.len();
345
346 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
347 .to_dtype(DType::F32)?
348 .reshape((max_seq_len, 1))?;
349
350 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
352 let freqs_long = t.matmul(&inv_freq_long)?;
353 let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
354 let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
355
356 let inv_freq_short =
358 Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
359 let freqs_short = t.matmul(&inv_freq_short)?;
360 let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
361 let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
362
363 Ok(Self {
364 short_cos,
365 short_sin,
366 long_cos: Some(long_cos),
367 long_sin: Some(long_sin),
368 original_max_position_embeddings: cfg.original_max_position_embeddings,
369 })
370 }
371
372 fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
373 let max_seq_len = cfg.max_position_embeddings;
374 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
375
376 let inv_freq: Vec<_> = (0..dim)
377 .step_by(2)
378 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
379 .collect();
380 let inv_freq_len = inv_freq.len();
381 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
382 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
383 .to_dtype(DType::F32)?
384 .reshape((max_seq_len, 1))?;
385 let freqs = t.matmul(&inv_freq)?;
386 let sin = freqs.sin()?.to_dtype(dtype)?;
387 let cos = freqs.cos()?.to_dtype(dtype)?;
388 Ok(Self {
389 short_cos: cos,
390 short_sin: sin,
391 long_cos: None,
392 long_sin: None,
393 original_max_position_embeddings: cfg.original_max_position_embeddings,
394 })
395 }
396
397 #[allow(clippy::too_many_arguments)]
398 fn new_scaled(
399 short_factor: &[f64],
400 long_factor: &[f64],
401 scaling_type: &ScaledRopeType,
402 long_mscale: f64,
403 short_mscale: f64,
404 cfg: &PhiRopeConfig,
405 dtype: DType,
406 dev: &Device,
407 ) -> Result<Self> {
408 let max_seq_len = cfg.max_position_embeddings;
409 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
410
411 if !matches!(scaling_type, ScaledRopeType::Su) {
412 candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
413 }
414
415 if short_factor.len() != dim / 2 {
416 candle_core::bail!(
417 "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
418 short_factor.len(),
419 dim / 2
420 );
421 }
422 if long_factor.len() != dim / 2 {
423 candle_core::bail!(
424 "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
425 long_factor.len(),
426 dim / 2
427 );
428 }
429
430 let inv_freq_short: Vec<_> = (0..dim)
432 .step_by(2)
433 .enumerate()
434 .map(|(k, i)| {
435 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
436 })
437 .collect();
438 let inv_freq_len_short = inv_freq_short.len();
439 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
440 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
441 .to_dtype(DType::F32)?
442 .reshape((max_seq_len, 1))?;
443 let freqs_short = t_short.matmul(&inv_freq_short)?;
444 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
445 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
446
447 let inv_freq_long: Vec<_> = (0..dim)
449 .step_by(2)
450 .enumerate()
451 .map(|(k, i)| {
452 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
453 })
454 .collect();
455 let inv_freq_len_long = inv_freq_long.len();
456 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
457 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
458 .to_dtype(DType::F32)?
459 .reshape((max_seq_len, 1))?;
460 let freqs_long = t_long.matmul(&inv_freq_long)?;
461 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
462 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
463 Ok(Self {
464 short_cos: cos_short,
465 short_sin: sin_short,
466 long_cos: Some(cos_long),
467 long_sin: Some(sin_long),
468 original_max_position_embeddings: cfg.original_max_position_embeddings,
469 })
470 }
471
472 pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
473 let cfg: PhiRopeConfig = cfg.into();
474
475 match &cfg.rope_scaling {
476 Some(PhiRopeScalingConfig::Classic {
477 short_factor,
478 long_factor,
479 scaling_type,
480 }) => {
481 Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
482 }
483
484 Some(PhiRopeScalingConfig::Scaled {
485 short_factor,
486 long_factor,
487 scaling_type,
488 long_mscale,
489 short_mscale,
490 }) => Self::new_scaled(
491 short_factor,
492 long_factor,
493 scaling_type,
494 *long_mscale,
495 *short_mscale,
496 &cfg,
497 dtype,
498 dev,
499 ),
500
501 None => Self::new_unscaled(&cfg, dtype, dev),
502 }
503 }
504
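    /// Use the long sin/cos tables only when the highest position id reaches
    /// `original_max_position_embeddings`; otherwise fall back to the short tables.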
505 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
507 if self.long_cos.is_none() {
508 return (&self.short_sin, &self.short_cos);
509 }
510 let seq_len = position_ids.iter().max().unwrap() + 1;
511 if seq_len > self.original_max_position_embeddings {
512 (
513 self.long_sin.as_ref().unwrap(),
514 self.long_cos.as_ref().unwrap(),
515 )
516 } else {
517 (&self.short_sin, &self.short_cos)
518 }
519 }
520
521 pub fn forward(
522 &self,
523 q: &Tensor,
524 k: &Tensor,
525 seqlen_offsets: &[usize],
526 position_ids: &[usize],
527 ) -> Result<(Tensor, Tensor)> {
528 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
529 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
530
531 let rot_dim = cos.dim(D::Minus1)? * 2;
532
        if rot_dim != q.dim(D::Minus1)? {
536 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
537 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
538 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
539 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
540
541 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
542 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
543 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
544 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
545 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
546 (q_embed, k_embed)
547 } else {
548 let mut q_embeds = Vec::new();
549 let mut k_embeds = Vec::new();
550 for (i, offset) in seqlen_offsets.iter().enumerate() {
551 let cos = cos.narrow(0, *offset, seq_len)?;
552 let sin = sin.narrow(0, *offset, seq_len)?;
553 let q_embed = candle_nn::rotary_emb::rope(
554 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
555 &cos,
556 &sin,
557 )?;
558 let k_embed = candle_nn::rotary_emb::rope(
559 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
560 &cos,
561 &sin,
562 )?;
563 q_embeds.push(q_embed);
564 k_embeds.push(k_embed);
565 }
566 let q_rot = Tensor::cat(&q_embeds, 0)?;
567 let k_rot = Tensor::cat(&k_embeds, 0)?;
568 (q_rot, k_rot)
569 };
570
571 Ok((
572 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
573 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
574 ))
575 } else if seqlen_offsets.len() == 1 {
576 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
577 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
578 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
579 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
580 Ok((q_embed, k_embed))
581 } else {
582 let mut q_embeds = Vec::new();
583 let mut k_embeds = Vec::new();
584 for (i, offset) in seqlen_offsets.iter().enumerate() {
585 let cos = cos.narrow(0, *offset, seq_len)?;
586 let sin = sin.narrow(0, *offset, seq_len)?;
587 let q_embed =
588 candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
589 let k_embed =
590 candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
591 q_embeds.push(q_embed);
592 k_embeds.push(k_embed);
593 }
594 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
595 }
596 }
597}
598
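/// RoPE with optional Llama-3 frequency scaling: high-frequency (short wavelength)
/// bands are kept, low-frequency bands are divided by `factor`, and the range in
/// between is smoothly interpolated.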
599#[derive(Debug, Clone)]
601pub struct Llama3RotaryEmbedding(RotaryEmbedding);
602
603#[derive(Debug, Clone, Deserialize, Serialize, Default)]
604pub enum Llama3RopeType {
605 #[serde(rename = "llama3")]
606 Llama3,
607 #[default]
608 #[serde(rename = "default")]
609 Default,
610}
611
612#[derive(Debug, Clone, Deserialize, Serialize, Default)]
613pub struct Llama3RopeConfig {
614 pub factor: f32,
615 pub low_freq_factor: f32,
616 pub high_freq_factor: f32,
617 pub original_max_position_embeddings: usize,
618 pub rope_type: Llama3RopeType,
619}
620
621fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
622 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
623 (0..head_dim)
624 .step_by(2)
625 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
626 .collect()
627}
628
629fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
630 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
631 (0..head_dim)
632 .step_by(2)
633 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
634 .collect()
635}
636
637impl Llama3RotaryEmbedding {
639 pub fn new_llama3(
640 dtype: DType,
641 cfg: &llama::Config,
642 dev: &Device,
643 is_gpt_neox: bool,
644 ) -> Result<Self> {
645 match &cfg.rope_scaling {
646 None
647 | Some(Llama3RopeConfig {
648 rope_type: Llama3RopeType::Default,
649 ..
650 }) => Ok(Self(RotaryEmbedding::new(
651 cfg.rope_theta,
652 cfg.hidden_size / cfg.num_attention_heads,
653 cfg.max_position_embeddings,
654 dev,
655 is_gpt_neox,
656 dtype,
657 )?)),
658 Some(rope_scaling) => {
659 let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
660 / rope_scaling.low_freq_factor;
661 let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
662 / rope_scaling.high_freq_factor;
663
664 let inv_freq = calculate_default_inv_freq(cfg)
665 .into_iter()
666 .map(|freq| {
667 let wavelen = 2. * PI / freq;
668 if wavelen < high_freq_wavelen {
669 freq
670 } else if wavelen > low_freq_wavelen {
671 freq / rope_scaling.factor
672 } else {
673 let smooth = (rope_scaling.original_max_position_embeddings as f32
674 / wavelen
675 - rope_scaling.low_freq_factor)
676 / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
677 (1. - smooth) * freq / rope_scaling.factor + smooth * freq
678 }
679 })
680 .collect::<Vec<_>>();
681 let inv_freq_len = inv_freq.len();
682 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
683
684 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
685 .to_dtype(DType::F32)?
686 .reshape((cfg.max_position_embeddings, 1))?;
687 let freqs = t.matmul(&inv_freq)?;
688 let sin = freqs.sin()?.to_dtype(dtype)?;
689 let cos = freqs.cos()?.to_dtype(dtype)?;
690 Ok(Self(RotaryEmbedding {
691 sin,
692 cos,
693 is_gpt_neox,
694 }))
695 }
696 }
697 }
698
699 pub fn new_llama4(
700 dtype: DType,
701 cfg: &llama4::TextConfig,
702 dev: &Device,
703 is_gpt_neox: bool,
704 ) -> Result<Self> {
705 match &cfg.rope_scaling {
706 None
707 | Some(Llama3RopeConfig {
708 rope_type: Llama3RopeType::Default,
709 ..
710 }) => Ok(Self(RotaryEmbedding::new(
711 cfg.rope_theta,
712 cfg.hidden_size / cfg.num_attention_heads,
713 cfg.max_position_embeddings,
714 dev,
715 is_gpt_neox,
716 dtype,
717 )?)),
718 Some(rope_scaling) => {
719 let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
720 / rope_scaling.low_freq_factor;
721 let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
722 / rope_scaling.high_freq_factor;
723
724 let inv_freq = calculate_default_inv_freq_llama4(cfg)
725 .into_iter()
726 .map(|freq| {
727 let wavelen = 2. * PI / freq;
728 if wavelen < high_freq_wavelen {
729 freq
730 } else if wavelen > low_freq_wavelen {
731 freq / rope_scaling.factor
732 } else {
733 let smooth = (rope_scaling.original_max_position_embeddings as f32
734 / wavelen
735 - rope_scaling.low_freq_factor)
736 / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
737 (1. - smooth) * freq / rope_scaling.factor + smooth * freq
738 }
739 })
740 .collect::<Vec<_>>();
741 let inv_freq_len = inv_freq.len();
742 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
743
744 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
745 .to_dtype(DType::F32)?
746 .reshape((cfg.max_position_embeddings, 1))?;
747 let freqs = t.matmul(&inv_freq)?;
748 let sin = freqs.sin()?.to_dtype(dtype)?;
749 let cos = freqs.cos()?.to_dtype(dtype)?;
750 Ok(Self(RotaryEmbedding {
751 sin,
752 cos,
753 is_gpt_neox,
754 }))
755 }
756 }
757 }
758
759 pub fn new_mllama3(
760 dtype: DType,
761 cfg: &MLlamaTextConfig,
762 dev: &Device,
763 is_gpt_neox: bool,
764 ) -> Result<Self> {
765 match &cfg.rope_scaling {
766 None
767 | Some(MLlamaRopeScaling {
768 rope_type: MLlamaRopeType::Default,
769 ..
770 }) => Ok(Self(RotaryEmbedding::new(
771 cfg.rope_theta,
772 cfg.hidden_size / cfg.num_attention_heads,
773 cfg.max_position_embeddings,
774 dev,
775 is_gpt_neox,
776 dtype,
777 )?)),
778 Some(MLlamaRopeScaling {
779 rope_type: MLlamaRopeType::Llama3,
780 original_max_position_embeddings,
781 factor,
782 attention_factor: _,
783 beta_fast: _,
784 beta_slow: _,
785 short_factor: _,
786 long_factor: _,
787 low_freq_factor,
788 high_freq_factor,
789 }) => {
790 let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
791 let low_freq_factor = low_freq_factor
792 .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
793 let high_freq_factor = high_freq_factor
794 .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
795
796 let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
797 let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
798
799 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
800
801 let inv_freq = (0..head_dim)
802 .step_by(2)
803 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
804 .map(|freq| {
805 let wavelen = 2. * PI / freq;
806 if wavelen < high_freq_wavelen {
807 freq
808 } else if wavelen > low_freq_wavelen {
809 freq / factor
810 } else {
811 let smooth = (*original_max_position_embeddings as f32 / wavelen
812 - low_freq_factor)
813 / (high_freq_factor - low_freq_factor);
814 (1. - smooth) * freq / factor + smooth * freq
815 }
816 })
817 .collect::<Vec<_>>();
818 let inv_freq_len = inv_freq.len();
819 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
820
821 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
822 .to_dtype(DType::F32)?
823 .reshape((cfg.max_position_embeddings, 1))?;
824 let freqs = t.matmul(&inv_freq)?;
825 let sin = freqs.sin()?.to_dtype(dtype)?;
826 let cos = freqs.cos()?.to_dtype(dtype)?;
827 Ok(Self(RotaryEmbedding {
828 sin,
829 cos,
830 is_gpt_neox,
831 }))
832 }
833 Some(MLlamaRopeScaling {
834 rope_type: other, ..
835 }) => {
836 candle_core::bail!(
837 "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
838 )
839 }
840 }
841 }
842
843 pub fn forward(
844 &self,
845 q: &Tensor,
846 k: &Tensor,
847 seqlen_offsets: &[usize],
848 ) -> Result<(Tensor, Tensor)> {
849 self.0.forward(q, k, seqlen_offsets)
850 }
851}
852
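/// Qwen2-VL multimodal RoPE (M-RoPE): 3-D position ids (temporal, height, width)
/// produce per-axis cos/sin tables that are recombined according to `mrope_section`.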
853#[derive(Debug, Clone)]
855pub struct Qwen2VLRotaryEmbedding {
856 inv_freq: Tensor,
857 mrope_section: Vec<usize>,
858}
859
860impl Qwen2VLRotaryEmbedding {
861 pub fn new(
862 base: f32,
863 head_dim: usize,
864 device: &Device,
865 mrope_section: Vec<usize>,
866 ) -> Result<Self> {
867 let inv_freq: Vec<_> = (0..head_dim)
868 .step_by(2)
869 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
870 .collect();
871 let inv_freq_len = inv_freq.len();
872 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
873 Ok(Self {
874 inv_freq,
875 mrope_section,
876 })
877 }
878
879 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
881 let inv_freq_expanded =
882 self.inv_freq
883 .reshape((1, 1, (), 1))?
884 .repeat((3, position_ids.dim(1)?, 1, 1))?;
885 let position_ids_expanded = position_ids.unsqueeze(2)?;
886 let freqs = inv_freq_expanded
887 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
888 .transpose(2, 3)?;
889 let cos = freqs.cos()?;
890 let sin = freqs.sin()?;
891
892 let cos = Tensor::cat(
893 &cos.split(&self.mrope_section, D::Minus1)?
894 .into_iter()
895 .enumerate()
896 .map(|(i, m)| m.i(i % 3))
897 .collect::<Result<Vec<_>>>()?,
898 D::Minus1,
899 )?
900 .squeeze(0)?
901 .to_dtype(dtype)?
902 .contiguous()?;
903 let sin = Tensor::cat(
904 &sin.split(&self.mrope_section, D::Minus1)?
905 .into_iter()
906 .enumerate()
907 .map(|(i, m)| m.i(i % 3))
908 .collect::<Result<Vec<_>>>()?,
909 D::Minus1,
910 )?
911 .squeeze(0)?
912 .to_dtype(dtype)?
913 .contiguous()?;
914
915 Ok((cos, sin))
916 }
917
918 pub fn forward(
920 &self,
921 (cos, sin): &(Tensor, Tensor),
922 q: &mut Tensor,
923 k: &mut Tensor,
924 ) -> Result<()> {
925 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
926 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
927 Ok(())
928 }
929}
930
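/// Identical to [`Qwen2VLRotaryEmbedding`]; kept as a separate type for the Qwen2.5-VL configuration.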
931#[derive(Debug, Clone)]
932pub struct Qwen2_5VLRotaryEmbedding {
933 inv_freq: Tensor,
934 mrope_section: Vec<usize>,
935}
936
937impl Qwen2_5VLRotaryEmbedding {
938 pub fn new(
939 base: f32,
940 head_dim: usize,
941 device: &Device,
942 mrope_section: Vec<usize>,
943 ) -> Result<Self> {
944 let inv_freq: Vec<_> = (0..head_dim)
945 .step_by(2)
946 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
947 .collect();
948 let inv_freq_len = inv_freq.len();
949 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
950 Ok(Self {
951 inv_freq,
952 mrope_section,
953 })
954 }
955
956 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
958 let inv_freq_expanded =
959 self.inv_freq
960 .reshape((1, 1, (), 1))?
961 .repeat((3, position_ids.dim(1)?, 1, 1))?;
962 let position_ids_expanded = position_ids.unsqueeze(2)?;
963 let freqs = inv_freq_expanded
964 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
965 .transpose(2, 3)?;
966 let cos = freqs.cos()?;
967 let sin = freqs.sin()?;
968
969 let cos = Tensor::cat(
970 &cos.split(&self.mrope_section, D::Minus1)?
971 .into_iter()
972 .enumerate()
973 .map(|(i, m)| m.i(i % 3))
974 .collect::<Result<Vec<_>>>()?,
975 D::Minus1,
976 )?
977 .squeeze(0)?
978 .to_dtype(dtype)?
979 .contiguous()?;
980 let sin = Tensor::cat(
981 &sin.split(&self.mrope_section, D::Minus1)?
982 .into_iter()
983 .enumerate()
984 .map(|(i, m)| m.i(i % 3))
985 .collect::<Result<Vec<_>>>()?,
986 D::Minus1,
987 )?
988 .squeeze(0)?
989 .to_dtype(dtype)?
990 .contiguous()?;
991
992 Ok((cos, sin))
993 }
994
995 pub fn forward(
996 &self,
997 (cos, sin): &(Tensor, Tensor),
998 q: &mut Tensor,
999 k: &mut Tensor,
1000 ) -> Result<()> {
1001 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1002 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1003 Ok(())
1004 }
1005}
1006
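/// DeepSeek-V2 rotary embedding over the `qk_rope_head_dim` slice, with optional
/// YaRN scaling (a ramp between interpolated and extrapolated inverse frequencies,
/// plus an mscale correction applied to sin/cos).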
1007#[derive(Debug, Clone)]
1008pub struct DeepSeekV2RotaryEmbedding {
1009 sin: Tensor,
1010 cos: Tensor,
1011}
1012
1013#[derive(Debug, Clone, Deserialize, Serialize)]
1014#[serde(untagged)]
1015pub enum DeepSeekV2RopeScaling {
1016 Yarn {
1017 original_max_position_embeddings: usize,
1018 beta_fast: f32,
1019 beta_slow: f32,
1020 mscale: f32,
1021 mscale_all_dim: f32,
1022 factor: f32,
1023 #[serde(rename = "type")]
1024 scaling_type: ScaledRopeType,
1025 },
1026 LinearOrDynamic {
1027 #[serde(rename = "type")]
1028 scaling_type: ScaledRopeType,
1029 factor: f64,
1030 },
1031}
1032
1033pub struct DeepSeekV2RopeConfig {
1034 pub rope_scaling: Option<DeepSeekV2RopeScaling>,
1035 pub max_position_embeddings: usize,
1036 pub rope_theta: f32,
1037 pub qk_rope_head_dim: usize,
1038}
1039
1040impl DeepSeekV2RotaryEmbedding {
1041 fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1042 let max_seq_len = cfg.max_position_embeddings;
1043 let dim = cfg.qk_rope_head_dim;
1044
1045 let inv_freq: Vec<_> = (0..dim)
1046 .step_by(2)
1047 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1048 .collect();
1049 let inv_freq_len = inv_freq.len();
1050 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1051 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1052 .to_dtype(DType::F32)?
1053 .reshape((max_seq_len, 1))?;
1054 let freqs = t.matmul(&inv_freq)?;
1055
1056 let sin = freqs.sin()?.to_dtype(dtype)?;
1057 let cos = freqs.cos()?.to_dtype(dtype)?;
1058
1059 Ok(Self { sin, cos })
1060 }
1061
1062 fn yarn_find_correction_dim(
1063 num_rot: f32,
1064 dim: usize,
1065 base: f32,
1066 max_position_embeddings: usize,
1067 ) -> f32 {
1068 (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1069 / (2. * base.ln())
1070 }
1071
1072 fn yarn_find_correction_range(
1073 low_rot: f32,
1074 high_rot: f32,
1075 dim: usize,
1076 base: f32,
1077 max_position_embeddings: usize,
1078 ) -> (f32, f32) {
1079 let low =
1080 Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1081 let high =
1082 Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1083 (low.max(0.), high.min(dim as f32 - 1.))
1084 }
1085
1086 fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1087 if min == max {
1088 max += 0.001;
1090 }
1091 let linear_func =
1092 ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
        linear_func.clamp(0., 1.)
1094 }
1095
1096 pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1097 if scale <= 1. {
1098 return 1.;
1099 }
1100 0.1 * mscale * scale.ln() + 1.
1101 }
1102
1103 #[allow(clippy::too_many_arguments)]
1104 fn new_yarn(
1105 cfg: &DeepSeekV2RopeConfig,
1106 dtype: DType,
1107 dev: &Device,
1108 original_max_position_embeddings: usize,
1109 beta_fast: f32,
1110 beta_slow: f32,
1111 factor: f32,
1112 mscale: f32,
1113 mscale_all_dim: f32,
1114 ) -> Result<Self> {
1115 let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1116 .step_by(2)
1117 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1118 .collect();
1119 let freq_extra_len = freq_extra.len();
1120 let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1121 let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1122 .step_by(2)
1123 .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1124 .collect();
1125 let freq_inter_len = freq_inter.len();
1126 let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1127
1128 let (low, high) = Self::yarn_find_correction_range(
1129 beta_fast,
1130 beta_slow,
1131 cfg.qk_rope_head_dim,
1132 cfg.rope_theta,
1133 original_max_position_embeddings,
1134 );
1135 let inv_freq_mask =
1136 (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1137 let inv_freq = freq_inter
1138 .broadcast_mul(&(1. - &inv_freq_mask)?)?
1139 .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1140
1141 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1142 .to_dtype(DType::F32)?
1143 .reshape((cfg.max_position_embeddings, 1))?;
1144 let freqs = t.matmul(&inv_freq)?;
1145
1146 let mscale =
1147 Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1148 let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1149 let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1150
1151 Ok(Self { sin, cos })
1152 }
1153
1154 pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1155 match &cfg.rope_scaling {
1156 Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1157 scaling_type: _,
1158 factor: _,
1159 }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1160 Some(DeepSeekV2RopeScaling::Yarn {
1161 original_max_position_embeddings,
1162 beta_fast,
1163 beta_slow,
1164 factor,
1165 mscale,
1166 mscale_all_dim,
1167 scaling_type: _,
1168 }) => Self::new_yarn(
1169 cfg,
1170 dtype,
1171 dev,
1172 *original_max_position_embeddings,
1173 *beta_fast,
1174 *beta_slow,
1175 *factor,
1176 *mscale,
1177 *mscale_all_dim,
1178 ),
1179 None => Self::new_unscaled(cfg, dtype, dev),
1180 }
1181 }
1182
1183 pub fn forward(
1184 &self,
1185 q: &Tensor,
1186 k: &Tensor,
1187 seqlen_offsets: &[usize],
1188 ) -> Result<(Tensor, Tensor)> {
1189 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1190
1191 if seqlen_offsets.len() == 1 {
1192 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1193 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1194 let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1195 let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1196 Ok((q_embed, k_embed))
1197 } else {
1198 let mut q_embeds = Vec::new();
1199 let mut k_embeds = Vec::new();
1200 for (i, offset) in seqlen_offsets.iter().enumerate() {
1201 let cos = self.cos.narrow(0, *offset, seq_len)?;
1202 let sin = self.sin.narrow(0, *offset, seq_len)?;
1203 let q_embed = candle_nn::rotary_emb::rope_i(
1204 &q.i(i)?.unsqueeze(0)?.contiguous()?,
1205 &cos,
1206 &sin,
1207 )?;
1208 let k_embed = candle_nn::rotary_emb::rope_i(
1209 &k.i(i)?.unsqueeze(0)?.contiguous()?,
1210 &cos,
1211 &sin,
1212 )?;
1213 q_embeds.push(q_embed);
1214 k_embeds.push(k_embed);
1215 }
1216 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1217 }
1218 }
1219}
1220
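/// Phi-4 multimodal rotary embedding: partial-rotary LongRope with short/long
/// factor tables, mirroring [`PhiRotaryEmbedding`].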
1221#[derive(Debug, Clone)]
1222pub struct Phi4MMRotaryEmbedding {
1223 short_sin: Tensor,
1224 short_cos: Tensor,
1225 long_cos: Option<Tensor>,
1226 long_sin: Option<Tensor>,
1227 original_max_position_embeddings: usize,
1228}
1229
1230#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1231#[serde(rename_all = "lowercase")]
1232pub enum Phi4MMScaledRopeType {
1233 #[serde(alias = "longrope")]
1234 LongRope,
1235 #[default]
1236 Default,
1237}
1238
1239#[derive(Debug, Clone, Deserialize, Serialize)]
1240pub struct Phi4MMRopeScalingConfig {
1241 short_factor: Option<Vec<f64>>,
1242 long_factor: Option<Vec<f64>>,
1243 #[serde(rename = "type")]
1244 scaling_type: Phi4MMScaledRopeType,
1245}
1246
1247impl Phi4MMRotaryEmbedding {
1248 fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1249 let max_seq_len = cfg.max_position_embeddings;
1250 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1251
1252 let inv_freq: Vec<_> = (0..dim)
1253 .step_by(2)
1254 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1255 .collect();
1256 let inv_freq_len = inv_freq.len();
1257 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1258 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1259 .to_dtype(DType::F32)?
1260 .reshape((max_seq_len, 1))?;
1261 let freqs = t.matmul(&inv_freq)?;
1262 let sin = freqs.sin()?.to_dtype(dtype)?;
1263 let cos = freqs.cos()?.to_dtype(dtype)?;
1264 Ok(Self {
1265 short_cos: cos,
1266 short_sin: sin,
1267 long_cos: None,
1268 long_sin: None,
1269 original_max_position_embeddings: cfg.original_max_position_embeddings,
1270 })
1271 }
1272
1273 #[allow(clippy::too_many_arguments)]
1274 fn new_longrope(
1275 short_factor: &[f64],
1276 long_factor: &[f64],
1277 cfg: &Phi4MMConfig,
1278 dtype: DType,
1279 dev: &Device,
1280 ) -> Result<Self> {
1281 let max_seq_len = cfg.max_position_embeddings;
1282 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1283
1284 let scale =
1286 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1287 let scaling_factor = if scale <= 1.0 {
1288 1.0
1289 } else {
1290 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1291 };
1292
1293 let inv_freq_short: Vec<_> = (0..dim)
1295 .step_by(2)
1296 .enumerate()
1297 .map(|(k, i)| {
1298 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1299 })
1300 .collect();
1301 let inv_freq_len_short = inv_freq_short.len();
1302 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1303 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1304 .to_dtype(DType::F32)?
1305 .reshape((max_seq_len, 1))?;
1306 let freqs_short = t_short.matmul(&inv_freq_short)?;
1307 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1308 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1309
1310 let inv_freq_long: Vec<_> = (0..dim)
1312 .step_by(2)
1313 .enumerate()
1314 .map(|(k, i)| {
1315 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1316 })
1317 .collect();
1318 let inv_freq_len_long = inv_freq_long.len();
1319 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1320 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1321 .to_dtype(DType::F32)?
1322 .reshape((max_seq_len, 1))?;
1323 let freqs_long = t_long.matmul(&inv_freq_long)?;
1324 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1325 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1326
1327 Ok(Self {
1328 short_cos: cos_short,
1329 short_sin: sin_short,
1330 long_cos: Some(cos_long),
1331 long_sin: Some(sin_long),
1332 original_max_position_embeddings: cfg.original_max_position_embeddings,
1333 })
1334 }
1335
1336 pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1337 match &cfg.rope_scaling {
1338 Some(Phi4MMRopeScalingConfig {
1339 scaling_type: Phi4MMScaledRopeType::LongRope,
1340 short_factor: Some(short_factor),
1341 long_factor: Some(long_factor),
1342 }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1343
1344 _ => Self::new_unscaled(cfg, dtype, dev),
1345 }
1346 }
1347
1348 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1350 if self.long_cos.is_none() {
1351 return (&self.short_sin, &self.short_cos);
1352 }
1353 let seq_len = position_ids.iter().max().unwrap() + 1;
1354 if seq_len > self.original_max_position_embeddings {
1355 (
1356 self.long_sin.as_ref().unwrap(),
1357 self.long_cos.as_ref().unwrap(),
1358 )
1359 } else {
1360 (&self.short_sin, &self.short_cos)
1361 }
1362 }
1363
1364 pub fn forward(
1365 &self,
1366 q: &Tensor,
1367 k: &Tensor,
1368 seqlen_offsets: &[usize],
1369 position_ids: &[usize],
1370 ) -> Result<(Tensor, Tensor)> {
1371 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1372 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1373
1374 let rot_dim = cos.dim(D::Minus1)? * 2;
1375 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1376 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1377 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1378 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1379
1380 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1381 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1382 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1383 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1384 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1385 (q_embed, k_embed)
1386 } else {
1387 let mut q_embeds = Vec::new();
1388 let mut k_embeds = Vec::new();
1389 for (i, offset) in seqlen_offsets.iter().enumerate() {
1390 let cos = cos.narrow(0, *offset, seq_len)?;
1391 let sin = sin.narrow(0, *offset, seq_len)?;
1392 let q_embed = candle_nn::rotary_emb::rope(
1393 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1394 &cos,
1395 &sin,
1396 )?;
1397 let k_embed = candle_nn::rotary_emb::rope(
1398 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1399 &cos,
1400 &sin,
1401 )?;
1402 q_embeds.push(q_embed);
1403 k_embeds.push(k_embed);
1404 }
1405 let q_rot = Tensor::cat(&q_embeds, 0)?;
1406 let k_rot = Tensor::cat(&k_embeds, 0)?;
1407 (q_rot, k_rot)
1408 };
1409
1410 Ok((
1411 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1412 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1413 ))
1414 }
1415}
1416
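/// Gemma-3 rotary embedding; `linear` rope scaling simply divides the inverse
/// frequencies by `factor` (1.0 is used when no scaling is configured).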
1417#[derive(Debug, Clone)]
1418pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1419
1420#[derive(Debug, Clone, Deserialize, Serialize)]
1421#[serde(rename_all = "lowercase")]
1422pub enum Gemmma3ScaledRopeType {
1423 #[serde(alias = "linear")]
1424 Linear,
1425}
1426
1427#[derive(Debug, Clone, Deserialize, Serialize)]
1428pub struct Gemma3RopeScalingConfig {
1429 factor: f64,
1430 rope_type: Gemmma3ScaledRopeType,
1431}
1432
1433impl Gemma3RotaryEmbedding {
1434 fn new_linear(
1435 cfg: &Gemma3TextConfig,
1436 factor: f64,
1437 is_gpt_neox: bool,
1438 dtype: DType,
1439 dev: &Device,
1440 ) -> Result<Self> {
1441 let max_seq_len = cfg.max_position_embeddings;
1442 let dim = cfg.head_dim;
1443
1444 let inv_freq: Vec<_> = (0..dim)
1445 .step_by(2)
1446 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1447 .collect();
1448 let inv_freq_len = inv_freq.len();
1449 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1450 let inv_freq = (inv_freq / factor)?;
1451
1452 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1453 .to_dtype(DType::F32)?
1454 .reshape((max_seq_len, 1))?;
1455 let freqs = t.matmul(&inv_freq)?;
1456 let sin = freqs.sin()?.to_dtype(dtype)?;
1457 let cos = freqs.cos()?.to_dtype(dtype)?;
1458 Ok(Self(RotaryEmbedding {
1459 cos,
1460 sin,
1461 is_gpt_neox,
1462 }))
1463 }
1464
1465 pub fn new(
1466 is_gpt_neox: bool,
1467 dtype: DType,
1468 cfg: &Gemma3TextConfig,
1469 dev: &Device,
1470 ) -> Result<Self> {
1471 match &cfg.rope_scaling {
1472 Some(Gemma3RopeScalingConfig {
1473 rope_type: Gemmma3ScaledRopeType::Linear,
1474 factor,
1475 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1476
1477 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1478 }
1479 }
1480
1481 pub fn forward(
1482 &self,
1483 q: &Tensor,
1484 k: &Tensor,
1485 seqlen_offsets: &[usize],
1486 ) -> Result<(Tensor, Tensor)> {
1487 self.0.forward(q, k, seqlen_offsets)
1488 }
1489}
1490
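/// Linear layer backed by a `QMatMul` (quantized or plain tensor) with an optional
/// f32 bias; quantized inputs are upcast to f32 for the matmul and cast back to the
/// stored dtype afterwards.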
1491#[derive(Debug, Clone)]
1492pub struct QLinear {
1493 inner: QMatMul,
1494 bias: Option<Tensor>,
1495 dtype: DType,
1496}
1497
1498impl QLinear {
1499 pub fn new<R: std::io::Read + std::io::Seek>(
1500 ct: &mut Content<'_, R>,
1501 name: &str,
1502 device: &Device,
1503 ) -> Result<Self> {
1504 let w = ct.tensor(&format!("{name}.weight"), device)?;
1505 let b = ct.tensor(&format!("{name}.bias"), device)?;
1506 let inner = QMatMul::from_qtensor(w)?;
1507 let bias = b.dequantize(device)?;
1508 Ok(Self {
1509 inner,
1510 bias: Some(bias),
1511 dtype: DType::F32,
1512 })
1513 }
1514
1515 pub fn from_linear(linear: Linear) -> Self {
1516 Self {
1517 inner: QMatMul::Tensor(linear.weight().clone()),
1518 bias: linear.bias().cloned(),
1519 dtype: linear.weight().dtype(),
1520 }
1521 }
1522
1523 pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1524 let dtype = w.dtype();
1525 Self {
1526 inner: QMatMul::Tensor(w),
1527 bias: b,
1528 dtype,
1529 }
1530 }
1531
1532 pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1533 if let Some(ref b) = b {
1534 assert_eq!(b.dtype(), DType::F32);
1535 }
1536 Self {
1537 inner: QMatMul::QTensor(Arc::new(w)),
1538 bias: b,
1539 dtype: DType::F32,
1540 }
1541 }
1542
1543 pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1544 Self {
1545 inner,
1546 bias: old.bias.clone(),
1547 dtype: old.dtype,
1548 }
1549 }
1550
1551 pub fn inner(&mut self) -> &mut QMatMul {
1552 &mut self.inner
1553 }
1554
1555 pub fn inner_ref(&self) -> &QMatMul {
1556 &self.inner
1557 }
1558
1559 pub fn is_quant(&self) -> bool {
1560 matches!(self.inner, QMatMul::QTensor(_))
1561 }
1562
1563 pub fn bias(&self) -> Option<&Tensor> {
1564 self.bias.as_ref()
1565 }
1566
1567 pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1568 self.bias.as_mut()
1569 }
1570}
1571
1572impl Module for QLinear {
1573 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1574 let xs = if self.is_quant() {
1575 xs.to_dtype(DType::F32)?
1576 } else {
1577 xs.clone()
1578 };
1579 if let Some(bias) = &self.bias {
1580 self.inner
1581 .forward(&xs)?
1582 .broadcast_add(bias)?
1583 .to_dtype(self.dtype)
1584 } else {
1585 self.inner.forward(&xs)?.to_dtype(self.dtype)
1586 }
1587 }
1588}
1589
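/// Standard rotary embedding with precomputed sin/cos tables. `is_gpt_neox` selects
/// the half-rotation (NeoX) layout instead of the interleaved layout; on CUDA builds
/// a fused in-place kernel is used when the query and key head counts match.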
1590#[derive(Debug, Clone)]
1591pub struct RotaryEmbedding {
1592 cos: Tensor,
1593 sin: Tensor,
1594 is_gpt_neox: bool,
1595}
1596
1597impl RotaryEmbedding {
1598 pub fn new(
1599 base: f32,
1600 head_dim: usize,
1601 max_position_embeddings: usize,
1602 device: &Device,
1603 is_gpt_neox: bool,
1604 dtype: DType,
1605 ) -> Result<Self> {
1606 let inv_freq: Vec<_> = (0..head_dim)
1607 .step_by(2)
1608 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1609 .collect();
1610 let inv_freq_len = inv_freq.len();
1611 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1612 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1613 .to_dtype(DType::F32)?
1614 .reshape((max_position_embeddings, 1))?;
1615 let freqs = t.matmul(&inv_freq)?;
1616 let sin = freqs.sin()?.to_dtype(dtype)?;
1617 let cos = freqs.cos()?.to_dtype(dtype)?;
1618
1619 Ok(Self {
1620 cos,
1621 sin,
1622 is_gpt_neox,
1623 })
1624 }
1625
1626 pub fn new_partial(
1627 base: f32,
1628 rot_dim: usize,
1629 max_position_embeddings: usize,
1630 device: &Device,
1631 is_gpt_neox: bool,
1632 dtype: DType,
1633 ) -> Result<Self> {
1634 let inv_freq: Vec<_> = (0..rot_dim)
1635 .step_by(2)
1636 .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
1637 .collect();
1638 let inv_freq_len = inv_freq.len();
1639 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1640 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1641 .to_dtype(DType::F32)?
1642 .reshape((max_position_embeddings, 1))?;
1643 let freqs = t.matmul(&inv_freq)?;
1644 let sin = freqs.sin()?.to_dtype(dtype)?;
1645 let cos = freqs.cos()?.to_dtype(dtype)?;
1646
1647 Ok(Self {
1648 cos,
1649 sin,
1650 is_gpt_neox,
1651 })
1652 }
1653
1654 pub fn forward(
1655 &self,
1656 q: &Tensor,
1657 k: &Tensor,
1658 seqlen_offsets: &[usize],
1659 ) -> Result<(Tensor, Tensor)> {
1660 let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
1661 let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
1662
1663 let rope = if self.is_gpt_neox {
1664 candle_nn::rotary_emb::rope
1665 } else {
1666 candle_nn::rotary_emb::rope_i
1667 };
1668
1669 if cfg!(feature = "cuda") && qh == kh {
1670 let (cos, sin) = if seqlen_offsets.len() == 1 {
1671 (
1672 self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
1673 self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
1674 )
1675 } else {
1676 let mut cos_s = Vec::new();
1677 let mut sin_s = Vec::new();
1678 for offset in seqlen_offsets {
1679 cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
1680 sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
1681 }
1682 (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
1683 };
1684
1685 let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
1686 let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
1687 mistralrs_quant::rotary::apply_rotary_inplace(
1688 &q_embed,
1689 &k_embed,
1690 &cos,
1691 &sin,
1692 self.is_gpt_neox,
1693 )?;
1694 let mut q = q_embed
1695 .reshape((b_sz, seq_len, qh, n_embd))?
1696 .transpose(1, 2)?;
1697 let mut k = k_embed
1698 .reshape((b_sz, seq_len, kh, n_embd))?
1699 .transpose(1, 2)?;
1700 if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
1701 q = q.contiguous()?;
1702 k = k.contiguous()?;
1703 }
1704 Ok((q, k))
1705 } else if seqlen_offsets.len() == 1 {
1706 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1707 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1708 let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1709 let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
1710 Ok((q_embed, k_embed))
1711 } else {
1712 let mut q_embeds = Vec::new();
1713 let mut k_embeds = Vec::new();
1714 for (i, offset) in seqlen_offsets.iter().enumerate() {
1715 let cos = self.cos.narrow(0, *offset, seq_len)?;
1716 let sin = self.sin.narrow(0, *offset, seq_len)?;
1717 let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1718 let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1719 q_embeds.push(q_embed);
1720 k_embeds.push(k_embed);
1721 }
1722 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1723 }
1724 }
1725}
1726
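/// Activation functions used across model configs; the serde aliases match the
/// names found in Hugging Face `config.json` files (`gelu`, `gelu_new`,
/// `gelu_pytorch_tanh`, ...).
///
/// A minimal usage sketch (assuming `serde_json` and a tensor `x` are in scope):
/// ```ignore
/// let act: Activation = serde_json::from_str("\"silu\"")?;
/// let y = act.forward(&x)?;
/// ```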
1727#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
1728#[serde(rename_all = "lowercase")]
1729pub enum Activation {
1730 #[default]
1731 #[serde(alias = "gelu")]
1732 Gelu,
1733 #[serde(alias = "gelu_new")]
1734 NewGelu,
1735 Relu,
1736 Relu2,
1737 Relu6,
1738 Silu,
1739 Sigmoid,
1740 HardSigmoid,
1741 Swiglu,
1742 Swish,
1743 HardSwish,
1744 Elu(f64),
1745 LeakyRelu(f64),
1746 #[serde(alias = "gelu_pytorch_tanh")]
1747 GeluPytorchTanh,
1748 QuickGelu,
1749}
1750
1751impl Module for Activation {
1752 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1753 match self {
1754 Self::Gelu => xs.gelu_erf(),
1755 Self::NewGelu => xs.gelu(),
1757 Self::Relu => xs.relu(),
1758 Self::Relu2 => xs.relu()?.sqr(),
1759 Self::Relu6 => xs.clamp(0f32, 6f32),
1760 Self::Silu => xs.silu(),
1761 Self::Sigmoid => candle_nn::ops::sigmoid(xs),
1762 Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
1763 Self::Swiglu => candle_nn::ops::swiglu(xs),
1764 Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
1765 Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
1766 &Self::Elu(alpha) => xs.elu(alpha),
1767 &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
1768 Self::GeluPytorchTanh => xs.gelu(),
1769 Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
1770 }
1771 }
1772}
1773
1774impl TryInto<candle_nn::Activation> for Activation {
1775 type Error = candle_core::Error;
1776
1777 fn try_into(self) -> Result<candle_nn::Activation> {
1778 match self {
1779 Self::Gelu => Ok(candle_nn::Activation::Gelu),
1780 Self::Relu => Ok(candle_nn::Activation::Relu),
1781 Self::Silu => Ok(candle_nn::Activation::Silu),
1782 Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
1783 Self::Relu2 => Ok(candle_nn::Activation::Relu2),
1784 Self::Relu6 => Ok(candle_nn::Activation::Relu6),
1785 Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
1786 Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
1787 Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
1788 Self::Swish => Ok(candle_nn::Activation::Swish),
1789 Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
1790 Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
1791 Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
1792 Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
1793 Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
1794 }
1795 }
1796}
1797
1798#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1799pub struct Conv3dConfig {
1800 pub padding: usize,
1801 pub stride: usize,
1802 pub dilation: usize,
1803 pub groups: usize,
1804}
1805
1806impl Default for Conv3dConfig {
1807 fn default() -> Self {
1808 Self {
1809 padding: 0,
1810 stride: 1,
1811 dilation: 1,
1812 groups: 1,
1813 }
1814 }
1815}
1816
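/// 3-D convolution without bias, specialized to a temporal kernel size of 2: the
/// weight is split along the temporal axis into two 2-D convolutions whose outputs
/// are summed in `forward`.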
1817pub struct Conv3dNoBias {
1818 conv2d_1: Conv2d,
1819 conv2d_2: Conv2d,
1820}
1821
1822impl Conv3dNoBias {
1823 pub fn new(
1824 in_channels: usize,
1825 out_channels: usize,
1826 kernel_sizes: [usize; 3],
1827 cfg: Conv3dConfig,
1828 vb: ShardedVarBuilder,
1829 ) -> Result<Self> {
1830 let ws = vb.get(
1831 (
1832 out_channels,
1833 in_channels / cfg.groups,
1834 kernel_sizes[0],
1835 kernel_sizes[1],
1836 kernel_sizes[2],
1837 ),
1838 "weight",
1839 )?;
1840
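        // Split the (O, I, T, H, W) weight along the temporal axis; this assumes a
        // temporal kernel of 2, matching the two frames consumed in `forward`.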
1841 let w1 = ws.i((.., .., 0, .., ..))?;
1845 let w2 = ws.i((.., .., 1, .., ..))?;
1846
1847 let cfg = Conv2dConfig {
1848 padding: cfg.padding,
1849 stride: cfg.stride,
1850 dilation: cfg.dilation,
1851 groups: cfg.groups,
1852 };
1853
1854 Ok(Self {
1855 conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
1856 conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
1857 })
1858 }
1859}
1860
1861impl Module for Conv3dNoBias {
1862 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1863 let xs1 = xs.i((.., .., 0, .., ..))?;
1864 let xs2 = xs.i((.., .., 1, .., ..))?;
1865
1866 (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
1867 }
1868}
1869
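/// Helpers for detecting positive-infinity values when clamping activations for
/// reduced-precision dtypes.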
1870pub trait TensorInfExtend {
1871 fn is_inf(&self) -> Result<Self>
1872 where
1873 Self: Sized;
1874 fn any(&self) -> Result<bool>;
1875}
1876
1877impl TensorInfExtend for Tensor {
1878 fn is_inf(&self) -> Result<Self> {
1879 self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
1880 }
1881
1882 fn any(&self) -> Result<bool> {
1883 let sum = self.sum_all()?;
1884 match self.dtype() {
            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
1895 }
1896 }
1897}
1898
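/// Clamp a tensor to just below its dtype's maximum so that later half-precision
/// casts do not overflow; the bound is tightened further when the infinity check triggers.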
1899pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
1900 let mut max = match xs.dtype() {
1901 DType::U8 => u8::MAX as f32 - 1000.,
1902 DType::U32 => u32::MAX as f32 - 1000.,
1903 DType::I16 => i16::MAX as f32 - 1000.,
1904 DType::I32 => i32::MAX as f32 - 1000.,
1905 DType::I64 => i64::MAX as f32 - 1000.,
1906 DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
1907 DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
1908 DType::F32 => f32::MAX - 1000.,
1909 DType::F64 => f64::MAX as f32 - 1000.,
1910 DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
1911 };
1912 if xs.is_inf()?.any()? {
1913 max -= 1000.;
1914 }
1915 xs.clamp(-max, max)
1916}
1917
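/// Minimum, maximum, and epsilon of a floating-point `DType` (akin to `torch.finfo`).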
1918pub struct FloatInfo {
1919 pub min: f64,
1921 pub max: f64,
1923 pub eps: f64,
1925 pub dtype: DType,
1926}
1927
1928pub trait GetFloatInfo {
1929 fn finfo(&self) -> Result<FloatInfo>;
1930}
1931
1932impl GetFloatInfo for DType {
1933 fn finfo(&self) -> Result<FloatInfo> {
1934 let finfo = match self {
1935 Self::BF16 => FloatInfo {
1936 min: bf16::MIN.to_f64(),
1937 max: bf16::MAX.to_f64(),
1938 eps: bf16::EPSILON.to_f64(),
1939 dtype: DType::BF16,
1940 },
1941 Self::F16 => FloatInfo {
1942 min: f16::MIN.to_f64(),
1943 max: f16::MAX.to_f64(),
1944 eps: f16::EPSILON.to_f64(),
1945 dtype: DType::F16,
1946 },
1947 Self::F32 => FloatInfo {
1948 min: f32::MIN as f64,
1949 max: f32::MAX as f64,
1950 eps: f32::EPSILON as f64,
1951 dtype: DType::F32,
1952 },
1953 Self::F64 => FloatInfo {
1954 min: f64::MIN,
1955 max: f64::MAX,
1956 eps: f64::EPSILON,
1957 dtype: DType::F64,
1958 },
1959 Self::F8E4M3 => FloatInfo {
1960 min: F8E4M3::MIN.to_f64(),
1961 max: F8E4M3::MAX.to_f64(),
1962 eps: F8E4M3::EPSILON.to_f64(),
1963 dtype: DType::F8E4M3,
1964 },
1965 other => {
1966 candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
1967 }
1968 };
1969 Ok(finfo)
1970 }
1971}
1972
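/// Gated MLP with `gate`/`up`/`down` projections and a configurable activation,
/// using column-parallel gate/up and a row-parallel down projection for tensor parallelism.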
1973#[derive(Clone)]
1974pub struct Mlp {
1975 pub gate: Arc<dyn QuantMethod>,
1976 pub up: Arc<dyn QuantMethod>,
1977 pub down: Arc<dyn QuantMethod>,
1978 act: Activation,
1979 params: Vec<usize>,
1980}
1981
1982impl Mlp {
1983 pub fn new(
1984 vb: ShardedVarBuilder,
1985 hidden_size: usize,
1986 intermediate_size: usize,
1987 quantization_config: &Option<QuantizedConfig>,
1988 hidden_act: Activation,
1989 comm: &Arc<mistralrs_quant::Comm>,
1990 ) -> Result<Self> {
1991 Ok(Self {
1992 gate: ColumnParallelLayer::new(
1993 hidden_size,
1994 intermediate_size,
1995 quantization_config,
1996 false,
1997 comm,
1998 vb.pp("gate_proj"),
1999 )?,
2000 up: ColumnParallelLayer::new(
2001 hidden_size,
2002 intermediate_size,
2003 quantization_config,
2004 false,
2005 comm,
2006 vb.pp("up_proj"),
2007 )?,
2008 down: RowParallelLayer::new(
2009 intermediate_size,
2010 hidden_size,
2011 quantization_config,
2012 false,
2013 comm,
2014 vb.pp("down_proj"),
2015 )?,
2016 act: hidden_act,
2017 params: vec![hidden_size, intermediate_size],
2018 })
2019 }
2020
2021 pub fn replicate(
2022 params: &[usize],
2023 vb: ShardedVarBuilder,
2024 act: Activation,
2025 comm: &Arc<mistralrs_quant::Comm>,
2026 ) -> Result<Self> {
2027 Self::new(vb, params[0], params[1], &None, act, comm)
2028 }
2029
2030 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2031 let original_dtype = xs.dtype();
2032 let mut xs = xs.clone();
2033 if let Some(t) = self.gate.quantized_act_type() {
2034 xs = xs.to_dtype(t)?;
2035 }
2036 let lhs = self.gate.forward(&xs)?;
2037 let rhs = self.up.forward(&xs)?;
2038 let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
2039 &lhs,
2040 &rhs,
2041 self.act.try_into()?,
2042 )?)?;
2043 if self.gate.quantized_act_type().is_some() {
2044 res = res.to_dtype(original_dtype)?;
2045 }
2046 Ok(res)
2047 }
2048}
2049
2050impl AnyMoeTrainableLayer for Mlp {}
2051
2052impl MlpLayer for Mlp {
2053 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2054 let original_dtype = xs.dtype();
2055 let mut xs = xs.clone();
2056 if let Some(t) = self.gate.quantized_act_type() {
2057 xs = xs.to_dtype(t)?;
2058 }
2059 let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2060 let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
2061 let mut res = if matches!(
2062 self.act,
2063 Activation::Gelu | Activation::Silu | Activation::Relu
2064 ) {
2065 MatMul.qmethod_matmul(
2066 &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
2067 &*self.down,
2068 )?
2069 } else {
2070 MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2071 };
2072 if self.gate.quantized_act_type().is_some() {
2073 res = res.to_dtype(original_dtype)?;
2074 }
2075 Ok(res)
2076 }
2077 fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2078 vec![&mut self.gate, &mut self.up, &mut self.down]
2079 }
2080 fn clone(&self) -> Box<dyn MlpLayer> {
2081 Box::new(Clone::clone(self))
2082 }
2083 fn get_params(&self) -> &[usize] {
2084 &self.params
2085 }
2086 fn hidden_act(&self) -> Activation {
2087 self.act
2088 }
2089 fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2091 let gate = if let Some(ref delta) = deltas[0] {
2092 self.gate.add_delta_w(delta)?
2093 } else {
2094 self.gate.clone()
2095 };
2096 let up = if let Some(ref delta) = deltas[1] {
2097 self.up.add_delta_w(delta)?
2098 } else {
2099 self.up.clone()
2100 };
2101 let down = if let Some(ref delta) = deltas[2] {
2102 self.down.add_delta_w(delta)?
2103 } else {
2104 self.down.clone()
2105 };
2106
2107 Ok(Box::new(Self {
2108 gate,
2109 up,
2110 down,
2111 act: self.act,
2112 params: self.params.clone(),
2113 }))
2114 }
2115
2116 fn dtype_device(&self) -> (DType, Device) {
2117 self.gate.dtype_and_device()
2118 }
2119}
2120
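/// 2-D average pooling with a square kernel and an explicit stride.
///
/// A minimal usage sketch (assuming an NCHW tensor `x` with even spatial dims):
/// ```ignore
/// let pool = AvgPool2d::new(2, 2);
/// let y = pool.forward(&x)?; // height and width halved
/// ```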
2121pub struct AvgPool2d {
2122 kernel_size: usize,
2123 stride: usize,
2124}
2125
2126impl AvgPool2d {
2127 pub fn new(kernel_size: usize, stride: usize) -> Self {
2128 Self {
2129 kernel_size,
2130 stride,
2131 }
2132 }
2133
2134 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2135 xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2136 }
2137}
2138
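/// Reflection padding over the last two dimensions of an NCHW tensor; `padding` is
/// `(left, right, top, bottom)` and each border mirrors the rows/columns adjacent to
/// the edge (excluding the edge itself).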
2139pub struct ReflectionPad2d {
2146 padding: (usize, usize, usize, usize),
2147}
2148
2149impl ReflectionPad2d {
2150 pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2151 Self { padding }
2152 }
2153}
2154
2155impl Module for ReflectionPad2d {
2156 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2157 let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2158
2159 let (_n, _c, h, w) = xs.dims4()?;
2160
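        // Reflection excludes the edge element: for a left pad of `p`, take columns
        // `p, p-1, ..., 1`; the right/top/bottom pads mirror from the opposite edges.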
2161 let left_pad = if pad_left > 0 {
2164 let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2166 Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2167 } else {
2168 None
2169 };
2170
2171 let right_pad = if pad_right > 0 {
2173 let start = w as i64 - 2;
2175 let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2176 Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2177 } else {
2178 None
2179 };
2180
2181 let x_padded_width = match (left_pad, right_pad) {
2183 (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2184 (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2185 (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2186 (None, None) => xs.clone(),
2187 };
2188
2189 let top_pad = if pad_top > 0 {
2192 let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2193 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2194 } else {
2195 None
2196 };
2197
2198 let bottom_pad = if pad_bottom > 0 {
2200 let start = h as i64 - 2;
2201 let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2202 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2203 } else {
2204 None
2205 };
2206
2207 let x_padded = match (top_pad, bottom_pad) {
2209 (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2210 (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2211 (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2212 (None, None) => x_padded_width,
2213 };
2214
2215 Ok(x_padded)
2216 }
2217}
2218
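/// Embedding whose output is multiplied by a constant `scale` (e.g. `sqrt(hidden_size)`
/// in Gemma-style models).
///
/// Sketch (hypothetical `hidden_size` and `token_ids` values):
/// ```ignore
/// let emb = ScaledEmbedding::new((hidden_size as f64).sqrt(), embedding);
/// let xs = emb.forward(&token_ids)?;
/// ```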
2219pub struct ScaledEmbedding {
2220 scale: f64,
2221 embedding: Embedding,
2222}
2223
2224impl ScaledEmbedding {
2225 pub fn new(scale: f64, embedding: Embedding) -> Self {
2226 Self { scale, embedding }
2227 }
2228
2229 pub fn embeddings(&self) -> &Tensor {
2230 self.embedding.embeddings()
2231 }
2232}
2233
2234impl Module for ScaledEmbedding {
2235 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2236 xs.apply(&self.embedding)? * self.scale
2237 }
2238}