1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6 quantized::{QMatMul, QTensor},
7 Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10 BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
11 LayerNorm, LayerNormConfig, Linear, Module,
12};
13use float8::F8E4M3;
14use half::{bf16, f16};
15use mistralrs_quant::{
16 AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
17 ShardedVarBuilder,
18};
19use serde::{Deserialize, Serialize};
20
21pub use crate::attention::Sdpa;
22pub use crate::layers_masker::CausalMasker;
23pub use crate::layers_utils::repeat_kv;
24use crate::{
25 amoe::{AnyMoeTrainableLayer, MlpLayer},
26 gguf::Content,
27 models::llama,
28 ops::SplitOp,
29 vision_models::{
30 gemma3::config::Gemma3TextConfig,
31 llama4,
32 mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
33 phi4::Phi4MMConfig,
34 },
35};
36
37pub use mistralrs_quant::MatMul;
38
39pub fn embedding(
40 in_size: usize,
41 out_size: usize,
42 vb: ShardedVarBuilder,
43 config: &Option<QuantizedConfig>,
44) -> Result<Embedding> {
45 let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
47 let afq_layer =
48 AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
49 afq_layer.dequantize_w()?
50 } else {
51 vb.get_with_hints((in_size, out_size), "weight", Default::default())?
52 };
53 Ok(Embedding::new(embeddings, out_size))
54}
55
56pub fn layer_norm<C: Into<LayerNormConfig>>(
57 size: usize,
58 config: C,
59 vb: ShardedVarBuilder,
60) -> Result<LayerNorm> {
61 let config = config.into();
62 let weight = vb.get(size, "weight")?;
63 if config.affine {
64 let bias = vb.get(size, "bias")?;
65 Ok(LayerNorm::new(weight, bias, config.eps))
66 } else {
67 Ok(LayerNorm::new_no_bias(weight, config.eps))
68 }
69}
70
71pub fn batch_norm<C: Into<BatchNormConfig>>(
72 num_features: usize,
73 config: C,
74 vb: ShardedVarBuilder,
75) -> Result<BatchNorm> {
76 let config = config.into();
77 if config.eps < 0. {
78 candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
79 }
80 let running_mean = vb.get(num_features, "running_mean")?;
81 let running_var = vb.get(num_features, "running_var")?;
82
83 if config.affine {
84 let weight = vb.get(num_features, "weight")?;
85 let bias = vb.get(num_features, "bias")?;
86 BatchNorm::new(
87 num_features,
88 running_mean,
89 running_var,
90 weight,
91 bias,
92 config.eps,
93 )
94 } else {
95 BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
96 }
97}
98
99pub fn group_norm(
100 num_groups: usize,
101 num_channels: usize,
102 eps: f64,
103 vb: ShardedVarBuilder,
104) -> Result<GroupNorm> {
105 let weight = vb.get(num_channels, "weight")?;
106 let bias = vb.get(num_channels, "bias")?;
107 GroupNorm::new(weight, bias, num_channels, num_groups, eps)
108}
109
110pub fn conv2d(
111 in_channels: usize,
112 out_channels: usize,
113 kernel_size: usize,
114 cfg: Conv2dConfig,
115 vb: ShardedVarBuilder,
116) -> Result<Conv2d> {
117 let ws = vb.get(
118 (
119 out_channels,
120 in_channels / cfg.groups,
121 kernel_size,
122 kernel_size,
123 ),
124 "weight",
125 )?;
126 let bs = vb.get(out_channels, "bias")?;
127 Ok(Conv2d::new(ws, Some(bs), cfg))
128}
129
130pub fn conv2d_no_bias(
131 in_channels: usize,
132 out_channels: usize,
133 kernel_size: usize,
134 cfg: Conv2dConfig,
135 vb: ShardedVarBuilder,
136) -> Result<Conv2d> {
137 let ws = vb.get(
138 (
139 out_channels,
140 in_channels / cfg.groups,
141 kernel_size,
142 kernel_size,
143 ),
144 "weight",
145 )?;
146 Ok(Conv2d::new(ws, None, cfg))
147}
148
149pub fn conv1d(
150 in_channels: usize,
151 out_channels: usize,
152 kernel_size: usize,
153 cfg: Conv1dConfig,
154 vb: ShardedVarBuilder,
155) -> Result<Conv1d> {
156 let ws = vb.get(
157 (out_channels, in_channels / cfg.groups, kernel_size),
158 "weight",
159 )?;
160 let bs = vb.get(out_channels, "bias")?;
161 Ok(Conv1d::new(ws, Some(bs), cfg))
162}
163
164pub fn conv1d_no_bias(
165 in_channels: usize,
166 out_channels: usize,
167 kernel_size: usize,
168 cfg: Conv1dConfig,
169 vb: ShardedVarBuilder,
170) -> Result<Conv1d> {
171 let ws = vb.get(
172 (out_channels, in_channels / cfg.groups, kernel_size),
173 "weight",
174 )?;
175 Ok(Conv1d::new(ws, None, cfg))
176}
177
178pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
179 let ws = vb.get((out_dim, in_dim), "weight")?;
180 let bs = vb.get(out_dim, "bias")?;
181 Ok(Linear::new(ws, Some(bs)))
182}
183
184pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
185 let ws = vb.get((out_dim, in_dim), "weight")?;
186 Ok(Linear::new(ws, None))
187}
188
189pub fn linear_b(
190 in_dim: usize,
191 out_dim: usize,
192 bias: bool,
193 vb: ShardedVarBuilder,
194) -> Result<Linear> {
195 if bias {
196 linear(in_dim, out_dim, vb)
197 } else {
198 linear_no_bias(in_dim, out_dim, vb)
199 }
200}
201
202#[derive(Debug, Clone)]
203pub struct RmsNorm {
204 eps: f64,
205 weight: Tensor,
206}
207
208impl RmsNorm {
209 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
210 let w = vb.get(size, "weight")?;
211 Ok(Self { eps, weight: w })
212 }
213
214 pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
216 let w = vb.get(size, "weight")?;
217 let w = (w + 1.0)?;
218 Ok(Self { eps, weight: w })
219 }
220
221 pub fn undo_gemma(&self) -> Result<Self> {
223 Ok(Self {
224 eps: self.eps,
225 weight: (&self.weight - 1.0)?,
226 })
227 }
228
229 pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
230 Ok(Self { eps, weight: w })
231 }
232
233 pub fn weight(&self) -> &Tensor {
234 &self.weight
235 }
236}
237
238impl Module for RmsNorm {
239 fn forward(&self, x: &Tensor) -> Result<Tensor> {
240 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
241 }
242}
243
244#[derive(Debug, Clone)]
245pub struct F32RmsNorm {
246 w: Tensor,
247 eps: f64,
248}
249
250impl F32RmsNorm {
251 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
252 Ok(Self {
253 w: vb.get((size,), "weight")?,
254 eps,
255 })
256 }
257
258 pub fn weight(&self) -> &Tensor {
259 &self.w
260 }
261}
262
263impl Module for F32RmsNorm {
264 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
265 let initial_type = xs.dtype();
266 let mut xs = xs.to_dtype(DType::F32)?;
267 let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
268 xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
269 xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
270 }
271}
272
273#[derive(Debug, Clone)]
274pub struct QRmsNorm {
275 eps: f64,
276 weight: Tensor,
277}
278
279impl QRmsNorm {
280 pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
281 let scale = scale.dequantize(&scale.device())?;
282 Ok(Self {
283 eps: eps as f64,
284 weight: scale,
285 })
286 }
287
288 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
289 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
290 }
291}
292
293#[derive(Debug, Clone)]
295pub struct PhiRotaryEmbedding {
296 short_sin: Tensor,
297 short_cos: Tensor,
298 long_cos: Option<Tensor>,
299 long_sin: Option<Tensor>,
300 original_max_position_embeddings: usize,
301}
302
303#[derive(Debug, Clone, Deserialize, Serialize)]
304#[serde(rename_all = "lowercase")]
305pub enum ScaledRopeType {
306 #[serde(alias = "su")]
307 #[serde(alias = "longrope")]
308 Su,
309 #[serde(alias = "yarn")]
310 Yarn,
311 #[serde(alias = "dynamic")]
312 Dynamic,
313 #[serde(alias = "linear")]
314 Linear,
315}
316
317impl FromStr for ScaledRopeType {
318 type Err = candle_core::Error;
319 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
320 match s {
321 "su" | "longrope" => Ok(Self::Su),
322 "yarn" => Ok(Self::Yarn),
323 "linear" => Ok(Self::Linear),
324 "dynamic" => Ok(Self::Dynamic),
325 _ => Err(candle_core::Error::Msg(
326 "Expected either `su` or `yarn` scaled RoPE type.".to_string(),
327 )),
328 }
329 }
330}
331
332#[derive(Debug, Clone, Deserialize, Serialize)]
333#[serde(untagged)]
334pub enum PhiRopeScalingConfig {
335 Classic {
336 short_factor: Vec<f64>,
337 long_factor: Vec<f64>,
338 #[serde(rename = "type")]
339 scaling_type: ScaledRopeType,
340 },
341 Scaled {
342 short_factor: Vec<f64>,
343 long_factor: Vec<f64>,
344 #[serde(rename = "type")]
345 scaling_type: ScaledRopeType,
346 long_mscale: f64,
347 short_mscale: f64,
348 },
349}
350
351pub struct PhiRopeConfig {
352 pub rope_scaling: Option<PhiRopeScalingConfig>,
353 pub max_position_embeddings: usize,
354 pub original_max_position_embeddings: usize,
355 pub rope_theta: f64,
356 pub head_dim: usize,
357 pub partial_rotary_factor: Option<f64>,
358}
359
360impl PhiRotaryEmbedding {
361 fn new_classic_scaled(
362 short_factor: &[f64],
363 long_factor: &[f64],
364 scaling_type: &ScaledRopeType,
365 cfg: &PhiRopeConfig,
366 dtype: DType,
367 dev: &Device,
368 ) -> Result<Self> {
369 let max_seq_len = cfg.max_position_embeddings;
370 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
371
372 let scale =
374 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
375 let scaling_factor = if scale <= 1.0 {
376 1.0
377 } else {
378 match scaling_type {
379 ScaledRopeType::Su => {
380 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
381 }
382 ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
383 _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
384 }
385 };
386
387 let inv_freq_long = (0..dim)
389 .step_by(2)
390 .enumerate()
391 .map(|(k, i)| {
392 (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
393 })
394 .collect::<Vec<_>>();
395 let inv_freq_short = (0..dim)
396 .step_by(2)
397 .enumerate()
398 .map(|(k, i)| {
399 (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
400 })
401 .collect::<Vec<_>>();
402 let inv_freq_len = inv_freq_long.len();
403
404 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
405 .to_dtype(DType::F32)?
406 .reshape((max_seq_len, 1))?;
407
408 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
410 let freqs_long = t.matmul(&inv_freq_long)?;
411 let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
412 let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
413
414 let inv_freq_short =
416 Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
417 let freqs_short = t.matmul(&inv_freq_short)?;
418 let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
419 let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
420
421 Ok(Self {
422 short_cos,
423 short_sin,
424 long_cos: Some(long_cos),
425 long_sin: Some(long_sin),
426 original_max_position_embeddings: cfg.original_max_position_embeddings,
427 })
428 }
429
430 fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
431 let max_seq_len = cfg.max_position_embeddings;
432 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
433
434 let inv_freq: Vec<_> = (0..dim)
435 .step_by(2)
436 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
437 .collect();
438 let inv_freq_len = inv_freq.len();
439 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
440 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
441 .to_dtype(DType::F32)?
442 .reshape((max_seq_len, 1))?;
443 let freqs = t.matmul(&inv_freq)?;
444 let sin = freqs.sin()?.to_dtype(dtype)?;
445 let cos = freqs.cos()?.to_dtype(dtype)?;
446 Ok(Self {
447 short_cos: cos,
448 short_sin: sin,
449 long_cos: None,
450 long_sin: None,
451 original_max_position_embeddings: cfg.original_max_position_embeddings,
452 })
453 }
454
455 #[allow(clippy::too_many_arguments)]
456 fn new_scaled(
457 short_factor: &[f64],
458 long_factor: &[f64],
459 scaling_type: &ScaledRopeType,
460 long_mscale: f64,
461 short_mscale: f64,
462 cfg: &PhiRopeConfig,
463 dtype: DType,
464 dev: &Device,
465 ) -> Result<Self> {
466 let max_seq_len = cfg.max_position_embeddings;
467 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
468
469 if !matches!(scaling_type, ScaledRopeType::Su) {
470 candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
471 }
472
473 if short_factor.len() != dim / 2 {
474 candle_core::bail!(
475 "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
476 short_factor.len(),
477 dim / 2
478 );
479 }
480 if long_factor.len() != dim / 2 {
481 candle_core::bail!(
482 "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
483 long_factor.len(),
484 dim / 2
485 );
486 }
487
488 let inv_freq_short: Vec<_> = (0..dim)
490 .step_by(2)
491 .enumerate()
492 .map(|(k, i)| {
493 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
494 })
495 .collect();
496 let inv_freq_len_short = inv_freq_short.len();
497 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
498 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
499 .to_dtype(DType::F32)?
500 .reshape((max_seq_len, 1))?;
501 let freqs_short = t_short.matmul(&inv_freq_short)?;
502 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
503 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
504
505 let inv_freq_long: Vec<_> = (0..dim)
507 .step_by(2)
508 .enumerate()
509 .map(|(k, i)| {
510 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
511 })
512 .collect();
513 let inv_freq_len_long = inv_freq_long.len();
514 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
515 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
516 .to_dtype(DType::F32)?
517 .reshape((max_seq_len, 1))?;
518 let freqs_long = t_long.matmul(&inv_freq_long)?;
519 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
520 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
521 Ok(Self {
522 short_cos: cos_short,
523 short_sin: sin_short,
524 long_cos: Some(cos_long),
525 long_sin: Some(sin_long),
526 original_max_position_embeddings: cfg.original_max_position_embeddings,
527 })
528 }
529
530 pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
531 let cfg: PhiRopeConfig = cfg.into();
532
533 match &cfg.rope_scaling {
534 Some(PhiRopeScalingConfig::Classic {
535 short_factor,
536 long_factor,
537 scaling_type,
538 }) => {
539 Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
540 }
541
542 Some(PhiRopeScalingConfig::Scaled {
543 short_factor,
544 long_factor,
545 scaling_type,
546 long_mscale,
547 short_mscale,
548 }) => Self::new_scaled(
549 short_factor,
550 long_factor,
551 scaling_type,
552 *long_mscale,
553 *short_mscale,
554 &cfg,
555 dtype,
556 dev,
557 ),
558
559 None => Self::new_unscaled(&cfg, dtype, dev),
560 }
561 }
562
563 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
565 if self.long_cos.is_none() {
566 return (&self.short_sin, &self.short_cos);
567 }
568 let seq_len = position_ids.iter().max().unwrap() + 1;
569 if seq_len > self.original_max_position_embeddings {
570 (
571 self.long_sin.as_ref().unwrap(),
572 self.long_cos.as_ref().unwrap(),
573 )
574 } else {
575 (&self.short_sin, &self.short_cos)
576 }
577 }
578
579 pub fn forward(
580 &self,
581 q: &Tensor,
582 k: &Tensor,
583 seqlen_offsets: &[usize],
584 position_ids: &[usize],
585 ) -> Result<(Tensor, Tensor)> {
586 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
587 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
588
589 let rot_dim = cos.dim(D::Minus1)? * 2;
590
591 if rot_dim != q.dim(D::Minus1)? {
593 let rot_dim = cos.dim(D::Minus1)? * 2;
594 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
595 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
596 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
597 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
598
599 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
600 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
601 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
602 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
603 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
604 (q_embed, k_embed)
605 } else {
606 let mut q_embeds = Vec::new();
607 let mut k_embeds = Vec::new();
608 for (i, offset) in seqlen_offsets.iter().enumerate() {
609 let cos = cos.narrow(0, *offset, seq_len)?;
610 let sin = sin.narrow(0, *offset, seq_len)?;
611 let q_embed = candle_nn::rotary_emb::rope(
612 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
613 &cos,
614 &sin,
615 )?;
616 let k_embed = candle_nn::rotary_emb::rope(
617 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
618 &cos,
619 &sin,
620 )?;
621 q_embeds.push(q_embed);
622 k_embeds.push(k_embed);
623 }
624 let q_rot = Tensor::cat(&q_embeds, 0)?;
625 let k_rot = Tensor::cat(&k_embeds, 0)?;
626 (q_rot, k_rot)
627 };
628
629 Ok((
630 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
631 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
632 ))
633 } else if seqlen_offsets.len() == 1 {
634 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
635 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
636 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
637 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
638 Ok((q_embed, k_embed))
639 } else {
640 let mut q_embeds = Vec::new();
641 let mut k_embeds = Vec::new();
642 for (i, offset) in seqlen_offsets.iter().enumerate() {
643 let cos = cos.narrow(0, *offset, seq_len)?;
644 let sin = sin.narrow(0, *offset, seq_len)?;
645 let q_embed =
646 candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
647 let k_embed =
648 candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
649 q_embeds.push(q_embed);
650 k_embeds.push(k_embed);
651 }
652 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
653 }
654 }
655}
656
657#[derive(Debug, Clone)]
659pub struct Llama3RotaryEmbedding(RotaryEmbedding);
660
661#[derive(Debug, Clone, Deserialize, Serialize, Default)]
662pub enum Llama3RopeType {
663 #[serde(rename = "llama3")]
664 Llama3,
665 #[serde(rename = "linear")]
666 Linear,
667 #[default]
668 #[serde(rename = "default")]
669 Default,
670}
671
672#[derive(Debug, Clone, Deserialize, Serialize, Default)]
673pub struct Llama3RopeConfig {
674 pub factor: f32,
675 pub low_freq_factor: Option<f32>,
676 pub high_freq_factor: Option<f32>,
677 pub original_max_position_embeddings: Option<usize>,
678 pub rope_type: Llama3RopeType,
679}
680
681fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
682 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
683 (0..head_dim)
684 .step_by(2)
685 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
686 .collect()
687}
688
689fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
690 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
691 (0..head_dim)
692 .step_by(2)
693 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
694 .collect()
695}
696
697impl Llama3RotaryEmbedding {
699 pub fn new_llama3(
700 dtype: DType,
701 cfg: &llama::Config,
702 dev: &Device,
703 is_gpt_neox: bool,
704 ) -> Result<Self> {
705 match &cfg.rope_scaling {
706 None
707 | Some(Llama3RopeConfig {
708 rope_type: Llama3RopeType::Default,
709 ..
710 }) => Ok(Self(RotaryEmbedding::new(
711 cfg.rope_theta,
712 cfg.hidden_size / cfg.num_attention_heads,
713 cfg.max_position_embeddings,
714 dev,
715 is_gpt_neox,
716 dtype,
717 )?)),
718 Some(Llama3RopeConfig {
719 rope_type: Llama3RopeType::Llama3,
720 factor,
721 low_freq_factor,
722 high_freq_factor,
723 original_max_position_embeddings,
724 }) => {
725 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
726 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
727 let original_max_position_embeddings = original_max_position_embeddings
728 .context("original_max_position_embeddings is required")?;
729
730 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
731 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
732
733 let inv_freq = calculate_default_inv_freq(cfg)
734 .into_iter()
735 .map(|freq| {
736 let wavelen = 2. * PI / freq;
737 if wavelen < high_freq_wavelen {
738 freq
739 } else if wavelen > low_freq_wavelen {
740 freq / *factor
741 } else {
742 let smooth = (original_max_position_embeddings as f32 / wavelen
743 - low_freq_factor)
744 / (high_freq_factor - low_freq_factor);
745 (1. - smooth) * freq / *factor + smooth * freq
746 }
747 })
748 .collect::<Vec<_>>();
749 let inv_freq_len = inv_freq.len();
750 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
751 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
752 .to_dtype(DType::F32)?
753 .reshape((cfg.max_position_embeddings, 1))?;
754 let freqs = t.matmul(&inv_freq)?;
755 let sin = freqs.sin()?.to_dtype(dtype)?;
756 let cos = freqs.cos()?.to_dtype(dtype)?;
757 Ok(Self(RotaryEmbedding {
758 sin,
759 cos,
760 is_gpt_neox,
761 }))
762 }
763 Some(Llama3RopeConfig {
764 rope_type: Llama3RopeType::Linear,
765 factor,
766 ..
767 }) => {
768 let inv_freq_vec = calculate_default_inv_freq(cfg)
769 .into_iter()
770 .map(|freq| freq / *factor)
771 .collect::<Vec<_>>();
772 let inv_freq_len = inv_freq_vec.len();
773 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
774 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
775 .to_dtype(DType::F32)?
776 .reshape((cfg.max_position_embeddings, 1))?;
777 let freqs = t.matmul(&inv_freq)?;
778 let sin = freqs.sin()?.to_dtype(dtype)?;
779 let cos = freqs.cos()?.to_dtype(dtype)?;
780 Ok(Self(RotaryEmbedding {
781 sin,
782 cos,
783 is_gpt_neox,
784 }))
785 }
786 }
787 }
788
789 pub fn new_llama4(
790 dtype: DType,
791 cfg: &llama4::TextConfig,
792 dev: &Device,
793 is_gpt_neox: bool,
794 ) -> Result<Self> {
795 match &cfg.rope_scaling {
796 None
797 | Some(Llama3RopeConfig {
798 rope_type: Llama3RopeType::Default,
799 ..
800 }) => Ok(Self(RotaryEmbedding::new(
801 cfg.rope_theta,
802 cfg.hidden_size / cfg.num_attention_heads,
803 cfg.max_position_embeddings,
804 dev,
805 is_gpt_neox,
806 dtype,
807 )?)),
808 Some(Llama3RopeConfig {
809 rope_type: Llama3RopeType::Llama3,
810 factor,
811 low_freq_factor,
812 high_freq_factor,
813 original_max_position_embeddings,
814 }) => {
815 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
816 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
817 let original_max_position_embeddings = original_max_position_embeddings
818 .context("original_max_position_embeddings is required")?;
819
820 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
821 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
822
823 let inv_freq = calculate_default_inv_freq_llama4(cfg)
824 .into_iter()
825 .map(|freq| {
826 let wavelen = 2. * PI / freq;
827 if wavelen < high_freq_wavelen {
828 freq
829 } else if wavelen > low_freq_wavelen {
830 freq / *factor
831 } else {
832 let smooth = (original_max_position_embeddings as f32 / wavelen
833 - low_freq_factor)
834 / (high_freq_factor - low_freq_factor);
835 (1. - smooth) * freq / *factor + smooth * freq
836 }
837 })
838 .collect::<Vec<_>>();
839 let inv_freq_len = inv_freq.len();
840 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
841 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
842 .to_dtype(DType::F32)?
843 .reshape((cfg.max_position_embeddings, 1))?;
844 let freqs = t.matmul(&inv_freq)?;
845 let sin = freqs.sin()?.to_dtype(dtype)?;
846 let cos = freqs.cos()?.to_dtype(dtype)?;
847 Ok(Self(RotaryEmbedding {
848 sin,
849 cos,
850 is_gpt_neox,
851 }))
852 }
853 Some(Llama3RopeConfig {
854 rope_type: Llama3RopeType::Linear,
855 factor,
856 ..
857 }) => {
858 let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
859 .into_iter()
860 .map(|freq| freq / *factor)
861 .collect::<Vec<_>>();
862 let inv_freq_len = inv_freq_vec.len();
863 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
864 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
865 .to_dtype(DType::F32)?
866 .reshape((cfg.max_position_embeddings, 1))?;
867 let freqs = t.matmul(&inv_freq)?;
868 let sin = freqs.sin()?.to_dtype(dtype)?;
869 let cos = freqs.cos()?.to_dtype(dtype)?;
870 Ok(Self(RotaryEmbedding {
871 sin,
872 cos,
873 is_gpt_neox,
874 }))
875 }
876 }
877 }
878
879 pub fn new_mllama3(
880 dtype: DType,
881 cfg: &MLlamaTextConfig,
882 dev: &Device,
883 is_gpt_neox: bool,
884 ) -> Result<Self> {
885 match &cfg.rope_scaling {
886 None
887 | Some(MLlamaRopeScaling {
888 rope_type: MLlamaRopeType::Default,
889 ..
890 }) => Ok(Self(RotaryEmbedding::new(
891 cfg.rope_theta,
892 cfg.hidden_size / cfg.num_attention_heads,
893 cfg.max_position_embeddings,
894 dev,
895 is_gpt_neox,
896 dtype,
897 )?)),
898 Some(MLlamaRopeScaling {
899 rope_type: MLlamaRopeType::Llama3,
900 original_max_position_embeddings,
901 factor,
902 attention_factor: _,
903 beta_fast: _,
904 beta_slow: _,
905 short_factor: _,
906 long_factor: _,
907 low_freq_factor,
908 high_freq_factor,
909 }) => {
910 let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
911 let low_freq_factor = low_freq_factor
912 .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
913 let high_freq_factor = high_freq_factor
914 .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
915
916 let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
917 let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
918
919 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
920
921 let inv_freq = (0..head_dim)
922 .step_by(2)
923 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
924 .map(|freq| {
925 let wavelen = 2. * PI / freq;
926 if wavelen < high_freq_wavelen {
927 freq
928 } else if wavelen > low_freq_wavelen {
929 freq / factor
930 } else {
931 let smooth = (*original_max_position_embeddings as f32 / wavelen
932 - low_freq_factor)
933 / (high_freq_factor - low_freq_factor);
934 (1. - smooth) * freq / factor + smooth * freq
935 }
936 })
937 .collect::<Vec<_>>();
938 let inv_freq_len = inv_freq.len();
939 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
940
941 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
942 .to_dtype(DType::F32)?
943 .reshape((cfg.max_position_embeddings, 1))?;
944 let freqs = t.matmul(&inv_freq)?;
945 let sin = freqs.sin()?.to_dtype(dtype)?;
946 let cos = freqs.cos()?.to_dtype(dtype)?;
947 Ok(Self(RotaryEmbedding {
948 sin,
949 cos,
950 is_gpt_neox,
951 }))
952 }
953 Some(MLlamaRopeScaling {
954 rope_type: other, ..
955 }) => {
956 candle_core::bail!(
957 "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
958 )
959 }
960 }
961 }
962
963 pub fn forward(
964 &self,
965 q: &Tensor,
966 k: &Tensor,
967 seqlen_offsets: &[usize],
968 ) -> Result<(Tensor, Tensor)> {
969 self.0.forward(q, k, seqlen_offsets)
970 }
971}
972
973#[derive(Debug, Clone)]
975pub struct Qwen2VLRotaryEmbedding {
976 inv_freq: Tensor,
977 mrope_section: Vec<usize>,
978}
979
980impl Qwen2VLRotaryEmbedding {
981 pub fn new(
982 base: f32,
983 head_dim: usize,
984 device: &Device,
985 mrope_section: Vec<usize>,
986 ) -> Result<Self> {
987 let inv_freq: Vec<_> = (0..head_dim)
988 .step_by(2)
989 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
990 .collect();
991 let inv_freq_len = inv_freq.len();
992 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
993 Ok(Self {
994 inv_freq,
995 mrope_section,
996 })
997 }
998
999 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1001 let inv_freq_expanded =
1002 self.inv_freq
1003 .reshape((1, 1, (), 1))?
1004 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1005 let position_ids_expanded = position_ids.unsqueeze(2)?;
1006 let freqs = inv_freq_expanded
1007 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1008 .transpose(2, 3)?;
1009 let cos = freqs.cos()?;
1010 let sin = freqs.sin()?;
1011
1012 let cos = Tensor::cat(
1013 &cos.split(&self.mrope_section, D::Minus1)?
1014 .into_iter()
1015 .enumerate()
1016 .map(|(i, m)| m.i(i % 3))
1017 .collect::<Result<Vec<_>>>()?,
1018 D::Minus1,
1019 )?
1020 .squeeze(0)?
1021 .to_dtype(dtype)?
1022 .contiguous()?;
1023 let sin = Tensor::cat(
1024 &sin.split(&self.mrope_section, D::Minus1)?
1025 .into_iter()
1026 .enumerate()
1027 .map(|(i, m)| m.i(i % 3))
1028 .collect::<Result<Vec<_>>>()?,
1029 D::Minus1,
1030 )?
1031 .squeeze(0)?
1032 .to_dtype(dtype)?
1033 .contiguous()?;
1034
1035 Ok((cos, sin))
1036 }
1037
1038 pub fn forward(
1040 &self,
1041 (cos, sin): &(Tensor, Tensor),
1042 q: &mut Tensor,
1043 k: &mut Tensor,
1044 ) -> Result<()> {
1045 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1046 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1047 Ok(())
1048 }
1049}
1050
1051#[derive(Debug, Clone)]
1052pub struct Qwen2_5VLRotaryEmbedding {
1053 inv_freq: Tensor,
1054 mrope_section: Vec<usize>,
1055}
1056
1057impl Qwen2_5VLRotaryEmbedding {
1058 pub fn new(
1059 base: f32,
1060 head_dim: usize,
1061 device: &Device,
1062 mrope_section: Vec<usize>,
1063 ) -> Result<Self> {
1064 let inv_freq: Vec<_> = (0..head_dim)
1065 .step_by(2)
1066 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1067 .collect();
1068 let inv_freq_len = inv_freq.len();
1069 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1070 Ok(Self {
1071 inv_freq,
1072 mrope_section,
1073 })
1074 }
1075
1076 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1078 let inv_freq_expanded =
1079 self.inv_freq
1080 .reshape((1, 1, (), 1))?
1081 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1082 let position_ids_expanded = position_ids.unsqueeze(2)?;
1083 let freqs = inv_freq_expanded
1084 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1085 .transpose(2, 3)?;
1086 let cos = freqs.cos()?;
1087 let sin = freqs.sin()?;
1088
1089 let cos = Tensor::cat(
1090 &cos.split(&self.mrope_section, D::Minus1)?
1091 .into_iter()
1092 .enumerate()
1093 .map(|(i, m)| m.i(i % 3))
1094 .collect::<Result<Vec<_>>>()?,
1095 D::Minus1,
1096 )?
1097 .squeeze(0)?
1098 .to_dtype(dtype)?
1099 .contiguous()?;
1100 let sin = Tensor::cat(
1101 &sin.split(&self.mrope_section, D::Minus1)?
1102 .into_iter()
1103 .enumerate()
1104 .map(|(i, m)| m.i(i % 3))
1105 .collect::<Result<Vec<_>>>()?,
1106 D::Minus1,
1107 )?
1108 .squeeze(0)?
1109 .to_dtype(dtype)?
1110 .contiguous()?;
1111
1112 Ok((cos, sin))
1113 }
1114
1115 pub fn forward(
1116 &self,
1117 (cos, sin): &(Tensor, Tensor),
1118 q: &mut Tensor,
1119 k: &mut Tensor,
1120 ) -> Result<()> {
1121 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1122 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1123 Ok(())
1124 }
1125}
1126
1127#[derive(Debug, Clone)]
1128pub struct DeepSeekV2RotaryEmbedding {
1129 sin: Tensor,
1130 cos: Tensor,
1131}
1132
/// RoPE scaling section of a DeepSeek-V2 configuration.
///
/// The enum is `#[serde(untagged)]`, so variants are tried in declaration order
/// during deserialization; do not reorder them.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    /// YaRN scaling parameters (consumed by `DeepSeekV2RotaryEmbedding::new_yarn`).
    Yarn {
        /// Context length passed to the correction-range computation.
        original_max_position_embeddings: usize,
        /// Rotation count used to derive the low end of the correction range.
        beta_fast: f32,
        /// Rotation count used to derive the high end of the correction range.
        beta_slow: f32,
        /// Numerator coefficient of the final sin/cos magnitude scale.
        mscale: f32,
        /// Denominator coefficient of the final sin/cos magnitude scale.
        mscale_all_dim: f32,
        /// Scaling factor applied to the interpolated frequencies.
        factor: f32,
        /// Scaling type; serialized under the key `type`.
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    /// Linear or dynamic scaling; `DeepSeekV2RotaryEmbedding::new` rejects this
    /// variant as not implemented.
    LinearOrDynamic {
        /// Scaling type; serialized under the key `type`.
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        /// Scaling factor.
        factor: f64,
    },
}
1152
/// Inputs needed to build a [`DeepSeekV2RotaryEmbedding`].
pub struct DeepSeekV2RopeConfig {
    /// Optional scaling configuration; `None` selects the unscaled tables.
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    /// Number of positions for which sin/cos rows are precomputed.
    pub max_position_embeddings: usize,
    /// Base of the inverse-frequency power law.
    pub rope_theta: f32,
    /// Number of rotary dimensions; half as many frequencies are generated.
    pub qk_rope_head_dim: usize,
}
1159
1160impl DeepSeekV2RotaryEmbedding {
1161 fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1162 let max_seq_len = cfg.max_position_embeddings;
1163 let dim = cfg.qk_rope_head_dim;
1164
1165 let inv_freq: Vec<_> = (0..dim)
1166 .step_by(2)
1167 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1168 .collect();
1169 let inv_freq_len = inv_freq.len();
1170 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1171 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1172 .to_dtype(DType::F32)?
1173 .reshape((max_seq_len, 1))?;
1174 let freqs = t.matmul(&inv_freq)?;
1175
1176 let sin = freqs.sin()?.to_dtype(dtype)?;
1177 let cos = freqs.cos()?.to_dtype(dtype)?;
1178
1179 Ok(Self { sin, cos })
1180 }
1181
1182 fn yarn_find_correction_dim(
1183 num_rot: f32,
1184 dim: usize,
1185 base: f32,
1186 max_position_embeddings: usize,
1187 ) -> f32 {
1188 (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1189 / (2. * base.ln())
1190 }
1191
1192 fn yarn_find_correction_range(
1193 low_rot: f32,
1194 high_rot: f32,
1195 dim: usize,
1196 base: f32,
1197 max_position_embeddings: usize,
1198 ) -> (f32, f32) {
1199 let low =
1200 Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1201 let high =
1202 Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1203 (low.max(0.), high.min(dim as f32 - 1.))
1204 }
1205
1206 fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1207 if min == max {
1208 max += 0.001;
1210 }
1211 let linear_func =
1212 ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1213 linear_func.clamp(0., 1)
1214 }
1215
1216 pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1217 if scale <= 1. {
1218 return 1.;
1219 }
1220 0.1 * mscale * scale.ln() + 1.
1221 }
1222
1223 #[allow(clippy::too_many_arguments)]
1224 fn new_yarn(
1225 cfg: &DeepSeekV2RopeConfig,
1226 dtype: DType,
1227 dev: &Device,
1228 original_max_position_embeddings: usize,
1229 beta_fast: f32,
1230 beta_slow: f32,
1231 factor: f32,
1232 mscale: f32,
1233 mscale_all_dim: f32,
1234 ) -> Result<Self> {
1235 let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1236 .step_by(2)
1237 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1238 .collect();
1239 let freq_extra_len = freq_extra.len();
1240 let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1241 let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1242 .step_by(2)
1243 .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1244 .collect();
1245 let freq_inter_len = freq_inter.len();
1246 let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1247
1248 let (low, high) = Self::yarn_find_correction_range(
1249 beta_fast,
1250 beta_slow,
1251 cfg.qk_rope_head_dim,
1252 cfg.rope_theta,
1253 original_max_position_embeddings,
1254 );
1255 let inv_freq_mask =
1256 (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1257 let inv_freq = freq_inter
1258 .broadcast_mul(&(1. - &inv_freq_mask)?)?
1259 .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1260
1261 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1262 .to_dtype(DType::F32)?
1263 .reshape((cfg.max_position_embeddings, 1))?;
1264 let freqs = t.matmul(&inv_freq)?;
1265
1266 let mscale =
1267 Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1268 let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1269 let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1270
1271 Ok(Self { sin, cos })
1272 }
1273
1274 pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1275 match &cfg.rope_scaling {
1276 Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1277 scaling_type: _,
1278 factor: _,
1279 }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1280 Some(DeepSeekV2RopeScaling::Yarn {
1281 original_max_position_embeddings,
1282 beta_fast,
1283 beta_slow,
1284 factor,
1285 mscale,
1286 mscale_all_dim,
1287 scaling_type: _,
1288 }) => Self::new_yarn(
1289 cfg,
1290 dtype,
1291 dev,
1292 *original_max_position_embeddings,
1293 *beta_fast,
1294 *beta_slow,
1295 *factor,
1296 *mscale,
1297 *mscale_all_dim,
1298 ),
1299 None => Self::new_unscaled(cfg, dtype, dev),
1300 }
1301 }
1302
1303 pub fn forward(
1304 &self,
1305 q: &Tensor,
1306 k: &Tensor,
1307 seqlen_offsets: &[usize],
1308 ) -> Result<(Tensor, Tensor)> {
1309 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1310
1311 if seqlen_offsets.len() == 1 {
1312 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1313 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1314 let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1315 let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1316 Ok((q_embed, k_embed))
1317 } else {
1318 let mut q_embeds = Vec::new();
1319 let mut k_embeds = Vec::new();
1320 for (i, offset) in seqlen_offsets.iter().enumerate() {
1321 let cos = self.cos.narrow(0, *offset, seq_len)?;
1322 let sin = self.sin.narrow(0, *offset, seq_len)?;
1323 let q_embed = candle_nn::rotary_emb::rope_i(
1324 &q.i(i)?.unsqueeze(0)?.contiguous()?,
1325 &cos,
1326 &sin,
1327 )?;
1328 let k_embed = candle_nn::rotary_emb::rope_i(
1329 &k.i(i)?.unsqueeze(0)?.contiguous()?,
1330 &cos,
1331 &sin,
1332 )?;
1333 q_embeds.push(q_embed);
1334 k_embeds.push(k_embed);
1335 }
1336 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1337 }
1338 }
1339}
1340
/// Rotary position embedding for Phi-4 multimodal, with optional LongRoPE tables.
///
/// The "short" tables are always present. The "long" tables exist only when the
/// embedding was built with LongRoPE scaling and are selected once the largest
/// position id plus one exceeds `original_max_position_embeddings`.
#[derive(Debug, Clone)]
pub struct Phi4MMRotaryEmbedding {
    /// Sine table built from the short factors (or unscaled).
    short_sin: Tensor,
    /// Cosine table built from the short factors (or unscaled).
    short_cos: Tensor,
    /// Cosine table built from the long factors, if LongRoPE is enabled.
    long_cos: Option<Tensor>,
    /// Sine table built from the long factors, if LongRoPE is enabled.
    long_sin: Option<Tensor>,
    /// Sequence-length threshold above which the long tables are used.
    original_max_position_embeddings: usize,
}
1349
/// RoPE scaling type for Phi-4 multimodal, (de)serialized in lowercase.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Phi4MMScaledRopeType {
    /// LongRoPE scaling with separate short and long factor lists.
    #[serde(alias = "longrope")]
    LongRope,
    /// No scaling; this is the `Default` variant.
    #[default]
    Default,
}
1358
/// RoPE scaling section of a Phi-4 multimodal configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Phi4MMRopeScalingConfig {
    /// Per-frequency divisors used for the short tables.
    short_factor: Option<Vec<f64>>,
    /// Per-frequency divisors used for the long tables.
    long_factor: Option<Vec<f64>>,
    /// Scaling type; serialized under the key `type`.
    #[serde(rename = "type")]
    scaling_type: Phi4MMScaledRopeType,
}
1366
1367impl Phi4MMRotaryEmbedding {
1368 fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1369 let max_seq_len = cfg.max_position_embeddings;
1370 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1371
1372 let inv_freq: Vec<_> = (0..dim)
1373 .step_by(2)
1374 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1375 .collect();
1376 let inv_freq_len = inv_freq.len();
1377 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1378 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1379 .to_dtype(DType::F32)?
1380 .reshape((max_seq_len, 1))?;
1381 let freqs = t.matmul(&inv_freq)?;
1382 let sin = freqs.sin()?.to_dtype(dtype)?;
1383 let cos = freqs.cos()?.to_dtype(dtype)?;
1384 Ok(Self {
1385 short_cos: cos,
1386 short_sin: sin,
1387 long_cos: None,
1388 long_sin: None,
1389 original_max_position_embeddings: cfg.original_max_position_embeddings,
1390 })
1391 }
1392
1393 #[allow(clippy::too_many_arguments)]
1394 fn new_longrope(
1395 short_factor: &[f64],
1396 long_factor: &[f64],
1397 cfg: &Phi4MMConfig,
1398 dtype: DType,
1399 dev: &Device,
1400 ) -> Result<Self> {
1401 let max_seq_len = cfg.max_position_embeddings;
1402 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1403
1404 let scale =
1406 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1407 let scaling_factor = if scale <= 1.0 {
1408 1.0
1409 } else {
1410 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1411 };
1412
1413 let inv_freq_short: Vec<_> = (0..dim)
1415 .step_by(2)
1416 .enumerate()
1417 .map(|(k, i)| {
1418 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1419 })
1420 .collect();
1421 let inv_freq_len_short = inv_freq_short.len();
1422 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1423 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1424 .to_dtype(DType::F32)?
1425 .reshape((max_seq_len, 1))?;
1426 let freqs_short = t_short.matmul(&inv_freq_short)?;
1427 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1428 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1429
1430 let inv_freq_long: Vec<_> = (0..dim)
1432 .step_by(2)
1433 .enumerate()
1434 .map(|(k, i)| {
1435 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1436 })
1437 .collect();
1438 let inv_freq_len_long = inv_freq_long.len();
1439 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1440 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1441 .to_dtype(DType::F32)?
1442 .reshape((max_seq_len, 1))?;
1443 let freqs_long = t_long.matmul(&inv_freq_long)?;
1444 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1445 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1446
1447 Ok(Self {
1448 short_cos: cos_short,
1449 short_sin: sin_short,
1450 long_cos: Some(cos_long),
1451 long_sin: Some(sin_long),
1452 original_max_position_embeddings: cfg.original_max_position_embeddings,
1453 })
1454 }
1455
1456 pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1457 match &cfg.rope_scaling {
1458 Some(Phi4MMRopeScalingConfig {
1459 scaling_type: Phi4MMScaledRopeType::LongRope,
1460 short_factor: Some(short_factor),
1461 long_factor: Some(long_factor),
1462 }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1463
1464 _ => Self::new_unscaled(cfg, dtype, dev),
1465 }
1466 }
1467
1468 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1470 if self.long_cos.is_none() {
1471 return (&self.short_sin, &self.short_cos);
1472 }
1473 let seq_len = position_ids.iter().max().unwrap() + 1;
1474 if seq_len > self.original_max_position_embeddings {
1475 (
1476 self.long_sin.as_ref().unwrap(),
1477 self.long_cos.as_ref().unwrap(),
1478 )
1479 } else {
1480 (&self.short_sin, &self.short_cos)
1481 }
1482 }
1483
1484 pub fn forward(
1485 &self,
1486 q: &Tensor,
1487 k: &Tensor,
1488 seqlen_offsets: &[usize],
1489 position_ids: &[usize],
1490 ) -> Result<(Tensor, Tensor)> {
1491 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1492 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1493
1494 let rot_dim = cos.dim(D::Minus1)? * 2;
1495 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1496 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1497 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1498 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1499
1500 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1501 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1502 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1503 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1504 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1505 (q_embed, k_embed)
1506 } else {
1507 let mut q_embeds = Vec::new();
1508 let mut k_embeds = Vec::new();
1509 for (i, offset) in seqlen_offsets.iter().enumerate() {
1510 let cos = cos.narrow(0, *offset, seq_len)?;
1511 let sin = sin.narrow(0, *offset, seq_len)?;
1512 let q_embed = candle_nn::rotary_emb::rope(
1513 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1514 &cos,
1515 &sin,
1516 )?;
1517 let k_embed = candle_nn::rotary_emb::rope(
1518 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1519 &cos,
1520 &sin,
1521 )?;
1522 q_embeds.push(q_embed);
1523 k_embeds.push(k_embed);
1524 }
1525 let q_rot = Tensor::cat(&q_embeds, 0)?;
1526 let k_rot = Tensor::cat(&k_embeds, 0)?;
1527 (q_rot, k_rot)
1528 };
1529
1530 Ok((
1531 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1532 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1533 ))
1534 }
1535}
1536
/// Rotary embedding for Gemma 3: a thin wrapper around [`RotaryEmbedding`] whose
/// tables are built with an optional linear scaling factor.
#[derive(Debug, Clone)]
pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1539
/// RoPE scaling type for Gemma 3, (de)serialized in lowercase.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3ScaledRopeType {
    /// Linear scaling: inverse frequencies are divided by the configured factor.
    #[serde(alias = "linear")]
    Linear,
}
1546
/// RoPE scaling section of a Gemma 3 text configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3RopeScalingConfig {
    /// Divisor applied to the inverse frequencies.
    factor: f64,
    /// Kind of scaling to apply.
    rope_type: Gemma3ScaledRopeType,
}
1552
1553impl Gemma3RotaryEmbedding {
1554 fn new_linear(
1555 cfg: &Gemma3TextConfig,
1556 factor: f64,
1557 is_gpt_neox: bool,
1558 dtype: DType,
1559 dev: &Device,
1560 ) -> Result<Self> {
1561 let max_seq_len = cfg.max_position_embeddings;
1562 let dim = cfg.head_dim;
1563
1564 let inv_freq: Vec<_> = (0..dim)
1565 .step_by(2)
1566 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1567 .collect();
1568 let inv_freq_len = inv_freq.len();
1569 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1570 let inv_freq = (inv_freq / factor)?;
1571
1572 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1573 .to_dtype(DType::F32)?
1574 .reshape((max_seq_len, 1))?;
1575 let freqs = t.matmul(&inv_freq)?;
1576 let sin = freqs.sin()?.to_dtype(dtype)?;
1577 let cos = freqs.cos()?.to_dtype(dtype)?;
1578 Ok(Self(RotaryEmbedding {
1579 cos,
1580 sin,
1581 is_gpt_neox,
1582 }))
1583 }
1584
1585 pub fn new(
1586 is_gpt_neox: bool,
1587 dtype: DType,
1588 cfg: &Gemma3TextConfig,
1589 dev: &Device,
1590 ) -> Result<Self> {
1591 match &cfg.rope_scaling {
1592 Some(Gemma3RopeScalingConfig {
1593 rope_type: Gemma3ScaledRopeType::Linear,
1594 factor,
1595 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1596
1597 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1598 }
1599 }
1600
1601 pub fn forward(
1602 &self,
1603 q: &Tensor,
1604 k: &Tensor,
1605 seqlen_offsets: &[usize],
1606 ) -> Result<(Tensor, Tensor)> {
1607 self.0.forward(q, k, seqlen_offsets)
1608 }
1609}
1610
/// Rotary embedding for Dia, computed on the fly from a timescale vector.
pub struct DiaRotaryEmbedding {
    /// One timescale per rotated feature pair (`head_dim / 2` entries); positions
    /// are divided by these to obtain rotation angles.
    timescale: Tensor,
    /// Dtype the sin/cos values are converted to before being applied.
    dtype: DType,
}
1615
1616impl DiaRotaryEmbedding {
1617 pub fn new(
1618 min_timescale: f32,
1619 max_timescale: f32,
1620 head_dim: usize,
1621 device: &Device,
1622 dtype: DType,
1623 ) -> Result<Self> {
1624 assert_eq!(head_dim % 2, 0);
1625 let half_embedding_dim = head_dim / 2;
1626
1627 let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
1628 let timescale = fraction
1629 .into_iter()
1630 .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
1631 .collect::<Vec<_>>();
1632
1633 let timescale_len = timescale.len();
1634 let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
1635
1636 Ok(Self { timescale, dtype })
1637 }
1638
1639 pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
1640 let freqs = positions
1641 .unsqueeze(D::Minus1)?
1642 .unsqueeze(D::Minus1)?
1643 .broadcast_div(&self.timescale)?;
1644
1645 let sin = freqs.sin()?.to_dtype(self.dtype)?;
1646 let cos = freqs.cos()?.to_dtype(self.dtype)?;
1647
1648 let split = xs.chunk(2, D::Minus1)?;
1649 let first_half = &split[0];
1650 let second_half = &split[1];
1651
1652 let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
1653 let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
1654
1655 Tensor::cat(&[first_part, second_part], D::Minus1)
1656 }
1657}
/// Linear layer whose weight is a [`QMatMul`] (quantized or plain tensor) with an
/// optional bias.
#[derive(Debug, Clone)]
pub struct QLinear {
    /// Weight used for the matrix multiplication.
    inner: QMatMul,
    /// Optional bias, broadcast-added to the matmul output.
    bias: Option<Tensor>,
    /// Dtype the output of `forward` is converted to.
    dtype: DType,
}
1664
1665impl QLinear {
1666 pub fn new<R: std::io::Read + std::io::Seek>(
1667 ct: &mut Content<'_, R>,
1668 name: &str,
1669 device: &Device,
1670 ) -> Result<Self> {
1671 let w = ct.tensor(&format!("{name}.weight"), device)?;
1672 let b = ct.tensor(&format!("{name}.bias"), device)?;
1673 let inner = QMatMul::from_qtensor(w)?;
1674 let bias = b.dequantize(device)?;
1675 Ok(Self {
1676 inner,
1677 bias: Some(bias),
1678 dtype: DType::F32,
1679 })
1680 }
1681
1682 pub fn from_linear(linear: Linear) -> Self {
1683 Self {
1684 inner: QMatMul::Tensor(linear.weight().clone()),
1685 bias: linear.bias().cloned(),
1686 dtype: linear.weight().dtype(),
1687 }
1688 }
1689
1690 pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1691 let dtype = w.dtype();
1692 Self {
1693 inner: QMatMul::Tensor(w),
1694 bias: b,
1695 dtype,
1696 }
1697 }
1698
1699 pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1700 if let Some(ref b) = b {
1701 assert_eq!(b.dtype(), DType::F32);
1702 }
1703 Self {
1704 inner: QMatMul::QTensor(Arc::new(w)),
1705 bias: b,
1706 dtype: DType::F32,
1707 }
1708 }
1709
1710 pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1711 Self {
1712 inner,
1713 bias: old.bias.clone(),
1714 dtype: old.dtype,
1715 }
1716 }
1717
1718 pub fn inner(&mut self) -> &mut QMatMul {
1719 &mut self.inner
1720 }
1721
1722 pub fn inner_ref(&self) -> &QMatMul {
1723 &self.inner
1724 }
1725
1726 pub fn is_quant(&self) -> bool {
1727 matches!(self.inner, QMatMul::QTensor(_))
1728 }
1729
1730 pub fn bias(&self) -> Option<&Tensor> {
1731 self.bias.as_ref()
1732 }
1733
1734 pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1735 self.bias.as_mut()
1736 }
1737}
1738
1739impl Module for QLinear {
1740 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1741 let xs = if self.is_quant() {
1742 xs.to_dtype(DType::F32)?
1743 } else {
1744 xs.clone()
1745 };
1746 if let Some(bias) = &self.bias {
1747 self.inner
1748 .forward(&xs)?
1749 .broadcast_add(bias)?
1750 .to_dtype(self.dtype)
1751 } else {
1752 self.inner.forward(&xs)?.to_dtype(self.dtype)
1753 }
1754 }
1755}
1756
/// Standard rotary position embedding backed by precomputed tables.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// Precomputed cosine table (one row per position).
    cos: Tensor,
    /// Precomputed sine table (one row per position).
    sin: Tensor,
    /// Selects `rope` (when `true`) versus the interleaved `rope_i` (when `false`).
    is_gpt_neox: bool,
}
1763
1764impl RotaryEmbedding {
1765 pub fn new(
1766 base: f32,
1767 head_dim: usize,
1768 max_position_embeddings: usize,
1769 device: &Device,
1770 is_gpt_neox: bool,
1771 dtype: DType,
1772 ) -> Result<Self> {
1773 let inv_freq: Vec<_> = (0..head_dim)
1774 .step_by(2)
1775 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1776 .collect();
1777 let inv_freq_len = inv_freq.len();
1778 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1779 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1780 .to_dtype(DType::F32)?
1781 .reshape((max_position_embeddings, 1))?;
1782 let freqs = t.matmul(&inv_freq)?;
1783 let sin = freqs.sin()?.to_dtype(dtype)?;
1784 let cos = freqs.cos()?.to_dtype(dtype)?;
1785
1786 Ok(Self {
1787 cos,
1788 sin,
1789 is_gpt_neox,
1790 })
1791 }
1792
1793 pub fn new_partial(
1794 base: f32,
1795 rot_dim: usize,
1796 max_position_embeddings: usize,
1797 device: &Device,
1798 is_gpt_neox: bool,
1799 dtype: DType,
1800 ) -> Result<Self> {
1801 let inv_freq: Vec<_> = (0..rot_dim)
1802 .step_by(2)
1803 .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
1804 .collect();
1805 let inv_freq_len = inv_freq.len();
1806 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1807 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1808 .to_dtype(DType::F32)?
1809 .reshape((max_position_embeddings, 1))?;
1810 let freqs = t.matmul(&inv_freq)?;
1811 let sin = freqs.sin()?.to_dtype(dtype)?;
1812 let cos = freqs.cos()?.to_dtype(dtype)?;
1813
1814 Ok(Self {
1815 cos,
1816 sin,
1817 is_gpt_neox,
1818 })
1819 }
1820
1821 pub fn forward(
1822 &self,
1823 q: &Tensor,
1824 k: &Tensor,
1825 seqlen_offsets: &[usize],
1826 ) -> Result<(Tensor, Tensor)> {
1827 let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
1828 let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
1829
1830 let rope = if self.is_gpt_neox {
1831 candle_nn::rotary_emb::rope
1832 } else {
1833 candle_nn::rotary_emb::rope_i
1834 };
1835
1836 if cfg!(feature = "cuda") && qh == kh {
1837 let (cos, sin) = if seqlen_offsets.len() == 1 {
1838 (
1839 self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
1840 self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
1841 )
1842 } else {
1843 let mut cos_s = Vec::new();
1844 let mut sin_s = Vec::new();
1845 for offset in seqlen_offsets {
1846 cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
1847 sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
1848 }
1849 (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
1850 };
1851
1852 let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
1853 let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
1854 mistralrs_quant::rotary::apply_rotary_inplace(
1855 &q_embed,
1856 &k_embed,
1857 &cos,
1858 &sin,
1859 self.is_gpt_neox,
1860 )?;
1861 let mut q = q_embed
1862 .reshape((b_sz, seq_len, qh, n_embd))?
1863 .transpose(1, 2)?;
1864 let mut k = k_embed
1865 .reshape((b_sz, seq_len, kh, n_embd))?
1866 .transpose(1, 2)?;
1867 if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
1868 q = q.contiguous()?;
1869 k = k.contiguous()?;
1870 }
1871 Ok((q, k))
1872 } else if seqlen_offsets.len() == 1 {
1873 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1874 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1875 let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1876 let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
1877 Ok((q_embed, k_embed))
1878 } else {
1879 let mut q_embeds = Vec::new();
1880 let mut k_embeds = Vec::new();
1881 for (i, offset) in seqlen_offsets.iter().enumerate() {
1882 let cos = self.cos.narrow(0, *offset, seq_len)?;
1883 let sin = self.sin.narrow(0, *offset, seq_len)?;
1884 let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1885 let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1886 q_embeds.push(q_embed);
1887 k_embeds.push(k_embed);
1888 }
1889 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1890 }
1891 }
1892}
1893
/// Activation functions selectable from a model configuration.
///
/// Variant names are (de)serialized in lowercase; some variants accept an extra
/// alias. `Gelu` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    /// GELU, applied via `gelu_erf`.
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    /// GELU applied via `gelu`; also accepted as `gelu_new`.
    #[serde(alias = "gelu_new")]
    NewGelu,
    /// ReLU.
    Relu,
    /// Squared ReLU.
    Relu2,
    /// ReLU clamped to `[0, 6]`.
    Relu6,
    /// SiLU.
    Silu,
    /// Sigmoid.
    Sigmoid,
    /// Hard sigmoid.
    HardSigmoid,
    /// SwiGLU.
    Swiglu,
    /// Swish: `x * sigmoid(x)`.
    Swish,
    /// Hard swish: `x * hard_sigmoid(x)`.
    HardSwish,
    /// ELU with the given alpha.
    Elu(f64),
    /// Leaky ReLU with the given negative slope.
    LeakyRelu(f64),
    /// GELU applied via `gelu`; also accepted as `gelu_pytorch_tanh`.
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
    /// Quick GELU: `x * sigmoid(1.702 * x)`.
    QuickGelu,
}
1917
1918impl Module for Activation {
1919 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1920 match self {
1921 Self::Gelu => xs.gelu_erf(),
1922 Self::NewGelu => xs.gelu(),
1924 Self::Relu => xs.relu(),
1925 Self::Relu2 => xs.relu()?.sqr(),
1926 Self::Relu6 => xs.clamp(0f32, 6f32),
1927 Self::Silu => xs.silu(),
1928 Self::Sigmoid => candle_nn::ops::sigmoid(xs),
1929 Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
1930 Self::Swiglu => candle_nn::ops::swiglu(xs),
1931 Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
1932 Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
1933 &Self::Elu(alpha) => xs.elu(alpha),
1934 &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
1935 Self::GeluPytorchTanh => xs.gelu(),
1936 Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
1937 }
1938 }
1939}
1940
1941impl TryInto<candle_nn::Activation> for Activation {
1942 type Error = candle_core::Error;
1943
1944 fn try_into(self) -> Result<candle_nn::Activation> {
1945 match self {
1946 Self::Gelu => Ok(candle_nn::Activation::Gelu),
1947 Self::Relu => Ok(candle_nn::Activation::Relu),
1948 Self::Silu => Ok(candle_nn::Activation::Silu),
1949 Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
1950 Self::Relu2 => Ok(candle_nn::Activation::Relu2),
1951 Self::Relu6 => Ok(candle_nn::Activation::Relu6),
1952 Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
1953 Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
1954 Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
1955 Self::Swish => Ok(candle_nn::Activation::Swish),
1956 Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
1957 Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
1958 Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
1959 Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
1960 Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
1961 }
1962 }
1963}
1964
/// Configuration for [`Conv3dNoBias`]; each value is forwarded unchanged to the
/// underlying 2D convolutions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    /// Padding applied by the convolution.
    pub padding: usize,
    /// Stride of the convolution.
    pub stride: usize,
    /// Dilation of the convolution.
    pub dilation: usize,
    /// Number of groups; the weight's input-channel axis is `in_channels / groups`.
    pub groups: usize,
}

impl Default for Conv3dConfig {
    /// No padding, with stride, dilation and groups all equal to one.
    fn default() -> Self {
        let (padding, stride, dilation, groups) = (0, 1, 1, 1);
        Self {
            padding,
            stride,
            dilation,
            groups,
        }
    }
}
1983
/// Bias-free 3D convolution emulated with two 2D convolutions, one for each of
/// the first two slices along the kernel's third axis.
pub struct Conv3dNoBias {
    /// 2D convolution built from kernel slice 0 of the third weight axis.
    conv2d_1: Conv2d,
    /// 2D convolution built from kernel slice 1 of the third weight axis.
    conv2d_2: Conv2d,
}
1988
1989impl Conv3dNoBias {
1990 pub fn new(
1991 in_channels: usize,
1992 out_channels: usize,
1993 kernel_sizes: [usize; 3],
1994 cfg: Conv3dConfig,
1995 vb: ShardedVarBuilder,
1996 ) -> Result<Self> {
1997 let ws = vb.get(
1998 (
1999 out_channels,
2000 in_channels / cfg.groups,
2001 kernel_sizes[0],
2002 kernel_sizes[1],
2003 kernel_sizes[2],
2004 ),
2005 "weight",
2006 )?;
2007
2008 let w1 = ws.i((.., .., 0, .., ..))?;
2012 let w2 = ws.i((.., .., 1, .., ..))?;
2013
2014 let cfg = Conv2dConfig {
2015 padding: cfg.padding,
2016 stride: cfg.stride,
2017 dilation: cfg.dilation,
2018 groups: cfg.groups,
2019 };
2020
2021 Ok(Self {
2022 conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
2023 conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
2024 })
2025 }
2026}
2027
2028impl Module for Conv3dNoBias {
2029 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2030 let xs1 = xs.i((.., .., 0, .., ..))?;
2031 let xs2 = xs.i((.., .., 1, .., ..))?;
2032
2033 (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
2034 }
2035}
2036
2037pub trait TensorInfExtend {
2038 fn is_inf(&self) -> Result<Self>
2039 where
2040 Self: Sized;
2041 fn any(&self) -> Result<bool>;
2042}
2043
2044impl TensorInfExtend for Tensor {
2045 fn is_inf(&self) -> Result<Self> {
2046 self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
2047 }
2048
2049 fn any(&self) -> Result<bool> {
2050 let sum = self.sum_all()?;
2051 match self.dtype() {
2052 DType::U8 => Ok(sum.to_scalar::<u8>()? == 0),
2053 DType::U32 => Ok(sum.to_scalar::<u32>()? == 0),
2054 DType::I16 => Ok(sum.to_scalar::<i16>()? == 0),
2055 DType::I32 => Ok(sum.to_scalar::<i32>()? == 0),
2056 DType::I64 => Ok(sum.to_scalar::<i64>()? == 0),
2057 DType::F16 => Ok(sum.to_scalar::<half::f16>()? == half::f16::from_f32_const(0.)),
2058 DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
2059 DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
2060 DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
2061 DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
2062 }
2063 }
2064}
2065
2066pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
2067 let mut max = match xs.dtype() {
2068 DType::U8 => u8::MAX as f32 - 1000.,
2069 DType::U32 => u32::MAX as f32 - 1000.,
2070 DType::I16 => i16::MAX as f32 - 1000.,
2071 DType::I32 => i32::MAX as f32 - 1000.,
2072 DType::I64 => i64::MAX as f32 - 1000.,
2073 DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
2074 DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
2075 DType::F32 => f32::MAX - 1000.,
2076 DType::F64 => f64::MAX as f32 - 1000.,
2077 DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
2078 };
2079 if xs.is_inf()?.any()? {
2080 max -= 1000.;
2081 }
2082 xs.clamp(-max, max)
2083}
2084
/// Numeric limits of a floating-point dtype, expressed as `f64`.
pub struct FloatInfo {
    /// The dtype's `MIN` constant.
    pub min: f64,
    /// The dtype's `MAX` constant.
    pub max: f64,
    /// The dtype's `EPSILON` constant.
    pub eps: f64,
    /// The dtype these limits describe.
    pub dtype: DType,
}
2094
/// Lookup of floating-point limits for a type that denotes a dtype.
pub trait GetFloatInfo {
    /// Returns the limits of `self` as a [`FloatInfo`], or an error if they
    /// cannot be provided.
    fn finfo(&self) -> Result<FloatInfo>;
}
2098
2099impl GetFloatInfo for DType {
2100 fn finfo(&self) -> Result<FloatInfo> {
2101 let finfo = match self {
2102 Self::BF16 => FloatInfo {
2103 min: bf16::MIN.to_f64(),
2104 max: bf16::MAX.to_f64(),
2105 eps: bf16::EPSILON.to_f64(),
2106 dtype: DType::BF16,
2107 },
2108 Self::F16 => FloatInfo {
2109 min: f16::MIN.to_f64(),
2110 max: f16::MAX.to_f64(),
2111 eps: f16::EPSILON.to_f64(),
2112 dtype: DType::F16,
2113 },
2114 Self::F32 => FloatInfo {
2115 min: f32::MIN as f64,
2116 max: f32::MAX as f64,
2117 eps: f32::EPSILON as f64,
2118 dtype: DType::F32,
2119 },
2120 Self::F64 => FloatInfo {
2121 min: f64::MIN,
2122 max: f64::MAX,
2123 eps: f64::EPSILON,
2124 dtype: DType::F64,
2125 },
2126 Self::F8E4M3 => FloatInfo {
2127 min: F8E4M3::MIN.to_f64(),
2128 max: F8E4M3::MAX.to_f64(),
2129 eps: F8E4M3::EPSILON.to_f64(),
2130 dtype: DType::F8E4M3,
2131 },
2132 other => {
2133 candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
2134 }
2135 };
2136 Ok(finfo)
2137 }
2138}
2139
/// Gated MLP block with separate gate, up and down projections.
#[derive(Clone)]
pub struct Mlp {
    /// Gate projection (`hidden_size -> intermediate_size`).
    pub gate: Arc<dyn QuantMethod>,
    /// Up projection (`hidden_size -> intermediate_size`).
    pub up: Arc<dyn QuantMethod>,
    /// Down projection (`intermediate_size -> hidden_size`).
    pub down: Arc<dyn QuantMethod>,
    /// Activation applied inside the block.
    act: Activation,
    /// `[hidden_size, intermediate_size]`, kept so the layer can be replicated.
    params: Vec<usize>,
}
2148
2149impl Mlp {
2150 pub fn new(
2151 vb: ShardedVarBuilder,
2152 hidden_size: usize,
2153 intermediate_size: usize,
2154 quantization_config: &Option<QuantizedConfig>,
2155 hidden_act: Activation,
2156 comm: &Arc<mistralrs_quant::Comm>,
2157 ) -> Result<Self> {
2158 Ok(Self {
2159 gate: ColumnParallelLayer::new(
2160 hidden_size,
2161 intermediate_size,
2162 quantization_config,
2163 false,
2164 comm,
2165 vb.pp("gate_proj"),
2166 )?,
2167 up: ColumnParallelLayer::new(
2168 hidden_size,
2169 intermediate_size,
2170 quantization_config,
2171 false,
2172 comm,
2173 vb.pp("up_proj"),
2174 )?,
2175 down: RowParallelLayer::new(
2176 intermediate_size,
2177 hidden_size,
2178 quantization_config,
2179 false,
2180 comm,
2181 vb.pp("down_proj"),
2182 )?,
2183 act: hidden_act,
2184 params: vec![hidden_size, intermediate_size],
2185 })
2186 }
2187
2188 pub fn new_merged(
2189 vb: ShardedVarBuilder,
2190 hidden_size: usize,
2191 intermediate_size: usize,
2192 chunks: usize,
2193 quantization_config: &Option<QuantizedConfig>,
2194 hidden_act: Activation,
2195 comm: &Arc<mistralrs_quant::Comm>,
2196 ) -> Result<Self> {
2197 assert!(chunks == 2, "Only gate_up_proj merge is supported!");
2198 let gate_up_projs = ColumnParallelLayer::new_merged(
2199 hidden_size,
2200 intermediate_size * 2,
2201 2,
2202 quantization_config,
2203 false,
2204 comm,
2205 vb.pp("gate_up_proj"),
2206 )?;
2207
2208 Ok(Self {
2209 gate: gate_up_projs[0].to_owned(),
2210 up: gate_up_projs[1].to_owned(),
2211 down: RowParallelLayer::new(
2212 intermediate_size,
2213 hidden_size,
2214 quantization_config,
2215 false,
2216 comm,
2217 vb.pp("down_proj"),
2218 )?,
2219 act: hidden_act,
2220 params: vec![hidden_size, intermediate_size],
2221 })
2222 }
2223
2224 pub fn replicate(
2225 params: &[usize],
2226 vb: ShardedVarBuilder,
2227 act: Activation,
2228 comm: &Arc<mistralrs_quant::Comm>,
2229 ) -> Result<Self> {
2230 Self::new(vb, params[0], params[1], &None, act, comm)
2231 }
2232
2233 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2234 let original_dtype = xs.dtype();
2235 let mut xs = xs.clone();
2236 if let Some(t) = self.gate.quantized_act_type() {
2237 xs = xs.to_dtype(t)?;
2238 }
2239 let lhs = self.gate.forward(&xs)?;
2240 let rhs = self.up.forward(&xs)?;
2241 let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
2242 &lhs,
2243 &rhs,
2244 self.act.try_into()?,
2245 )?)?;
2246 if self.gate.quantized_act_type().is_some() {
2247 res = res.to_dtype(original_dtype)?;
2248 }
2249 Ok(res)
2250 }
2251}
2252
2253impl AnyMoeTrainableLayer for Mlp {}
2254
2255impl MlpLayer for Mlp {
2256 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2257 let original_dtype = xs.dtype();
2258 let mut xs = xs.clone();
2259 if let Some(t) = self.gate.quantized_act_type() {
2260 xs = xs.to_dtype(t)?;
2261 }
2262 let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2263 let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
2264 let mut res = if matches!(
2265 self.act,
2266 Activation::Gelu | Activation::Silu | Activation::Relu
2267 ) {
2268 MatMul.qmethod_matmul(
2269 &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
2270 &*self.down,
2271 )?
2272 } else {
2273 MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2274 };
2275 if self.gate.quantized_act_type().is_some() {
2276 res = res.to_dtype(original_dtype)?;
2277 }
2278 Ok(res)
2279 }
2280 fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2281 vec![&mut self.gate, &mut self.up, &mut self.down]
2282 }
2283 fn clone(&self) -> Box<dyn MlpLayer> {
2284 Box::new(Clone::clone(self))
2285 }
2286 fn get_params(&self) -> &[usize] {
2287 &self.params
2288 }
2289 fn hidden_act(&self) -> Activation {
2290 self.act
2291 }
2292 fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2294 let gate = if let Some(ref delta) = deltas[0] {
2295 self.gate.add_delta_w(delta)?
2296 } else {
2297 self.gate.clone()
2298 };
2299 let up = if let Some(ref delta) = deltas[1] {
2300 self.up.add_delta_w(delta)?
2301 } else {
2302 self.up.clone()
2303 };
2304 let down = if let Some(ref delta) = deltas[2] {
2305 self.down.add_delta_w(delta)?
2306 } else {
2307 self.down.clone()
2308 };
2309
2310 Ok(Box::new(Self {
2311 gate,
2312 up,
2313 down,
2314 act: self.act,
2315 params: self.params.clone(),
2316 }))
2317 }
2318
2319 fn dtype_device(&self) -> (DType, Device) {
2320 self.gate.dtype_and_device()
2321 }
2322}
2323
2324pub struct AvgPool2d {
2325 kernel_size: usize,
2326 stride: usize,
2327}
2328
2329impl AvgPool2d {
2330 pub fn new(kernel_size: usize, stride: usize) -> Self {
2331 Self {
2332 kernel_size,
2333 stride,
2334 }
2335 }
2336
2337 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2338 xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2339 }
2340}
2341
/// Reflection padding over the last two dimensions of a 4-D tensor.
///
/// `padding` is read by the `Module` implementation as
/// `(left, right, top, bottom)`.
pub struct ReflectionPad2d {
    padding: (usize, usize, usize, usize),
}

impl ReflectionPad2d {
    /// Creates the layer from a `(left, right, top, bottom)` padding tuple.
    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
        ReflectionPad2d { padding }
    }
}
2357
2358impl Module for ReflectionPad2d {
2359 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2360 let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2361
2362 let (_n, _c, h, w) = xs.dims4()?;
2363
2364 let left_pad = if pad_left > 0 {
2367 let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2369 Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2370 } else {
2371 None
2372 };
2373
2374 let right_pad = if pad_right > 0 {
2376 let start = w as i64 - 2;
2378 let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2379 Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2380 } else {
2381 None
2382 };
2383
2384 let x_padded_width = match (left_pad, right_pad) {
2386 (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2387 (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2388 (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2389 (None, None) => xs.clone(),
2390 };
2391
2392 let top_pad = if pad_top > 0 {
2395 let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2396 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2397 } else {
2398 None
2399 };
2400
2401 let bottom_pad = if pad_bottom > 0 {
2403 let start = h as i64 - 2;
2404 let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2405 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2406 } else {
2407 None
2408 };
2409
2410 let x_padded = match (top_pad, bottom_pad) {
2412 (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2413 (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2414 (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2415 (None, None) => x_padded_width,
2416 };
2417
2418 Ok(x_padded)
2419 }
2420}
2421
2422pub struct ScaledEmbedding {
2423 scale: f64,
2424 embedding: Embedding,
2425}
2426
2427impl ScaledEmbedding {
2428 pub fn new(scale: f64, embedding: Embedding) -> Self {
2429 Self { scale, embedding }
2430 }
2431
2432 pub fn embeddings(&self) -> &Tensor {
2433 self.embedding.embeddings()
2434 }
2435}
2436
2437impl Module for ScaledEmbedding {
2438 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2439 xs.apply(&self.embedding)? * self.scale
2440 }
2441}