#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};

use candle_core::{
    quantized::{QMatMul, QTensor},
    Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{
    BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
    LayerNorm, LayerNormConfig, Linear, Module,
};
use float8::F8E4M3;
use half::{bf16, f16};
use mistralrs_quant::{
    AfqLayer, ColumnParallelLayer, Convolution, QuantMethod, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};

pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::repeat_kv;
use crate::{
    amoe::{AnyMoeTrainableLayer, MlpLayer},
    embedding_models::embedding_gemma::EmbeddingGemmaConfig,
    gguf::Content,
    models::{llama, smollm3},
    ops::SplitOp,
    vision_models::{
        gemma3::config::Gemma3TextConfig,
        gemma3n::config::Gemma3nTextConfig,
        llama4,
        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
        phi4::Phi4MMConfig,
    },
};

pub use mistralrs_quant::MatMul;

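/// Create an [`Embedding`] layer from `vb`. If an AFQ quantization config is present,
/// the quantized weight is loaded and dequantized; otherwise the plain `weight`
/// tensor of shape `(in_size, out_size)` is read directly.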
pub fn embedding(
    in_size: usize,
    out_size: usize,
    vb: ShardedVarBuilder,
    config: &Option<QuantizedConfig>,
) -> Result<Embedding> {
    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
        let afq_layer =
            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
        afq_layer.dequantize_w()?
    } else {
        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
    };
    Ok(Embedding::new(embeddings, out_size))
}

pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get(size, "weight")?;
    if config.affine {
        let bias = vb.get(size, "bias")?;
        Ok(LayerNorm::new(weight, bias, config.eps))
    } else {
        Ok(LayerNorm::new_no_bias(weight, config.eps))
    }
}

pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<BatchNorm> {
    let config = config.into();
    if config.eps < 0. {
        candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    let running_mean = vb.get(num_features, "running_mean")?;
    let running_var = vb.get(num_features, "running_var")?;

    if config.affine {
        let weight = vb.get(num_features, "weight")?;
        let bias = vb.get(num_features, "bias")?;
        BatchNorm::new(
            num_features,
            running_mean,
            running_var,
            weight,
            bias,
            config.eps,
        )
    } else {
        BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
    }
}

pub fn group_norm(
    num_groups: usize,
    num_channels: usize,
    eps: f64,
    vb: ShardedVarBuilder,
) -> Result<GroupNorm> {
    let weight = vb.get(num_channels, "weight")?;
    let bias = vb.get(num_channels, "bias")?;
    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
}

pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}

pub fn conv2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}

pub fn conv1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

pub fn conv1d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    Ok(Conv1d::new(ws, None, cfg))
}

pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    let bs = vb.get(out_dim, "bias")?;
    Ok(Linear::new(ws, Some(bs)))
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    Ok(Linear::new(ws, None))
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: ShardedVarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}

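/// RMS normalization with a learned scale, applied via `candle_nn::ops::rms_norm`.
/// The `new_gemma*` constructors account for Gemma checkpoints storing `weight - 1`.
///
/// A minimal usage sketch (the variable names here are illustrative only):
///
/// ```ignore
/// let norm = RmsNorm::new(hidden_size, 1e-6, vb.pp("input_layernorm"))?;
/// let normed = norm.forward(&hidden_states)?;
/// ```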
#[derive(Debug, Clone)]
pub struct RmsNorm {
    eps: f64,
    weight: Tensor,
}

impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        Ok(Self { eps, weight: w })
    }

    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        let w = (w + 1.0)?;
        Ok(Self { eps, weight: w })
    }

    pub fn new_gemma_3n(
        size: usize,
        eps: f64,
        with_scale: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let w = if with_scale {
            vb.get(size, "weight")?
        } else {
            Tensor::ones(size, vb.dtype(), vb.device())?
        };
        Ok(Self { eps, weight: w })
    }

    pub fn undo_gemma(&self) -> Result<Self> {
        Ok(Self {
            eps: self.eps,
            weight: (&self.weight - 1.0)?,
        })
    }

    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
        Ok(Self { eps, weight: w })
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    w: Tensor,
    eps: f64,
}

impl F32RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }

    pub fn weight(&self) -> &Tensor {
        &self.w
    }
}

impl Module for F32RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}

#[derive(Debug, Clone)]
pub struct QRmsNorm {
    eps: f64,
    weight: Tensor,
}

impl QRmsNorm {
    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
        let scale = scale.dequantize(&scale.device())?;
        Ok(Self {
            eps: eps as f64,
            weight: scale,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

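/// Rotary position embedding for Phi-3 style models: precomputes "short" sin/cos
/// tables and, when a scaling config is provided, "long" tables that are selected
/// once the largest position id exceeds `original_max_position_embeddings`.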
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    #[serde(alias = "yarn")]
    Yarn,
    #[serde(alias = "dynamic")]
    Dynamic,
    #[serde(alias = "linear")]
    Linear,
}

impl FromStr for ScaledRopeType {
    type Err = candle_core::Error;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "su" | "longrope" => Ok(Self::Su),
            "yarn" => Ok(Self::Yarn),
            "linear" => Ok(Self::Linear),
            "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Expected `su`/`longrope`, `yarn`, `linear`, or `dynamic` scaled RoPE type."
                    .to_string(),
            )),
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
    Classic {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    Scaled {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
    },
}

pub struct PhiRopeConfig {
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub original_max_position_embeddings: usize,
    pub rope_theta: f64,
    pub head_dim: usize,
    pub partial_rotary_factor: Option<f64>,
}

impl PhiRotaryEmbedding {
    fn new_classic_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            match scaling_type {
                ScaledRopeType::Su => {
                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
                }
                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
            }
        };

        let inv_freq_long = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_short = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq_long.len();

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;

        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
        let freqs_long = t.matmul(&inv_freq_long)?;
        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        let inv_freq_short =
            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
        let freqs_short = t.matmul(&inv_freq_short)?;
        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        Ok(Self {
            short_cos,
            short_sin,
            long_cos: Some(long_cos),
            long_sin: Some(long_sin),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        if !matches!(scaling_type, ScaledRopeType::Su) {
            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
        }

        if short_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
                short_factor.len(),
                dim / 2
            );
        }
        if long_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
                long_factor.len(),
                dim / 2
            );
        }

        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;

        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
        let cfg: PhiRopeConfig = cfg.into();

        match &cfg.rope_scaling {
            Some(PhiRopeScalingConfig::Classic {
                short_factor,
                long_factor,
                scaling_type,
            }) => {
                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
            }

            Some(PhiRopeScalingConfig::Scaled {
                short_factor,
                long_factor,
                scaling_type,
                long_mscale,
                short_mscale,
            }) => Self::new_scaled(
                short_factor,
                long_factor,
                scaling_type,
                *long_mscale,
                *short_mscale,
                &cfg,
                dtype,
                dev,
            ),

            None => Self::new_unscaled(&cfg, dtype, dev),
        }
    }

    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        let rot_dim = cos.dim(D::Minus1)? * 2;

        if rot_dim != q.dim(D::Minus1)? {
            let rot_dim = cos.dim(D::Minus1)? * 2;
            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
                (q_embed, k_embed)
            } else {
                let mut q_embeds = Vec::new();
                let mut k_embeds = Vec::new();
                for (i, offset) in seqlen_offsets.iter().enumerate() {
                    let cos = cos.narrow(0, *offset, seq_len)?;
                    let sin = sin.narrow(0, *offset, seq_len)?;
                    let q_embed = candle_nn::rotary_emb::rope(
                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    let k_embed = candle_nn::rotary_emb::rope(
                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    q_embeds.push(q_embed);
                    k_embeds.push(k_embed);
                }
                let q_rot = Tensor::cat(&q_embeds, 0)?;
                let k_rot = Tensor::cat(&k_embeds, 0)?;
                (q_rot, k_rot)
            };

            Ok((
                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
            ))
        } else if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

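/// Rotary embedding with optional Llama 3 style frequency rescaling: low-frequency
/// components are divided by `factor`, high-frequency components are left unchanged,
/// and the band in between is smoothly interpolated using `low_freq_factor`,
/// `high_freq_factor`, and `original_max_position_embeddings`.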
#[derive(Debug, Clone)]
pub struct Llama3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[serde(rename = "linear")]
    Linear,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: Option<f32>,
    pub high_freq_factor: Option<f32>,
    pub original_max_position_embeddings: Option<usize>,
    pub rope_type: Llama3RopeType,
}

fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

impl Llama3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &llama::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_llama4(
        dtype: DType,
        cfg: &llama4::TextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_mllama3(
        dtype: DType,
        cfg: &MLlamaTextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Llama3,
                original_max_position_embeddings,
                factor,
                attention_factor: _,
                beta_fast: _,
                beta_slow: _,
                short_factor: _,
                long_factor: _,
                low_freq_factor,
                high_freq_factor,
            }) => {
                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
                let low_freq_factor = low_freq_factor
                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
                let high_freq_factor = high_freq_factor
                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;

                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;

                let inv_freq = (0..head_dim)
                    .step_by(2)
                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / factor
                        } else {
                            let smooth = (*original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(MLlamaRopeScaling {
                rope_type: other, ..
            }) => {
                candle_core::bail!(
                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
                )
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

#[derive(Debug, Clone)]
pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum SmolLm3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[serde(rename = "linear")]
    Linear,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SmolLm3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: Option<f32>,
    pub high_freq_factor: Option<f32>,
    pub original_max_position_embeddings: Option<usize>,
    pub rope_type: SmolLm3RopeType,
}

fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

impl SmolLm3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &smollm3::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq_smollm3(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

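/// Multimodal rotary embedding (M-RoPE) as used by Qwen2-VL: position ids carry three
/// components (temporal, height, width), and `mrope_section` describes how the rotary
/// dimensions are split across those components when building the cos/sin tables.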
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct Qwen2_5VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2_5VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct Qwen3VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen3VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

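/// Rotary embedding for the decoupled RoPE head dimension of DeepSeek-V2, with
/// optional YaRN scaling (correction range from `beta_fast`/`beta_slow`, a linear
/// ramp mask over frequencies, and an mscale factor applied to sin/cos).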
#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}

pub struct DeepSeekV2RopeConfig {
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub max_position_embeddings: usize,
    pub rope_theta: f32,
    pub qk_rope_head_dim: usize,
}

impl DeepSeekV2RotaryEmbedding {
    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.qk_rope_head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    fn yarn_find_correction_dim(
        num_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> f32 {
        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
            / (2. * base.ln())
    }

    fn yarn_find_correction_range(
        low_rot: f32,
        high_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> (f32, f32) {
        let low =
            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
        let high =
            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
        (low.max(0.), high.min(dim as f32 - 1.))
    }

    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
        if min == max {
            max += 0.001;
        }
        let linear_func =
            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
        linear_func.clamp(0., 1.)
    }

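    /// YaRN magnitude scaling: returns `0.1 * mscale * ln(scale) + 1.0` when
    /// `scale > 1`, and `1.0` otherwise.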
    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
        if scale <= 1. {
            return 1.;
        }
        0.1 * mscale * scale.ln() + 1.
    }

    #[allow(clippy::too_many_arguments)]
    fn new_yarn(
        cfg: &DeepSeekV2RopeConfig,
        dtype: DType,
        dev: &Device,
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        factor: f32,
        mscale: f32,
        mscale_all_dim: f32,
    ) -> Result<Self> {
        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
            .collect();
        let freq_extra_len = freq_extra.len();
        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
            .collect();
        let freq_inter_len = freq_inter.len();
        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;

        let (low, high) = Self::yarn_find_correction_range(
            beta_fast,
            beta_slow,
            cfg.qk_rope_head_dim,
            cfg.rope_theta,
            original_max_position_embeddings,
        );
        let inv_freq_mask =
            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
        let inv_freq = freq_inter
            .broadcast_mul(&(1. - &inv_freq_mask)?)?
            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;

        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((cfg.max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let mscale =
            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
                scaling_type: _,
                factor: _,
            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
            Some(DeepSeekV2RopeScaling::Yarn {
                original_max_position_embeddings,
                beta_fast,
                beta_slow,
                factor,
                mscale,
                mscale_all_dim,
                scaling_type: _,
            }) => Self::new_yarn(
                cfg,
                dtype,
                dev,
                *original_max_position_embeddings,
                *beta_fast,
                *beta_slow,
                *factor,
                *mscale,
                *mscale_all_dim,
            ),
            None => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope_i(
                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope_i(
                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

#[derive(Debug, Clone)]
pub struct Phi4MMRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Phi4MMScaledRopeType {
    #[serde(alias = "longrope")]
    LongRope,
    #[default]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Phi4MMRopeScalingConfig {
    short_factor: Option<Vec<f64>>,
    long_factor: Option<Vec<f64>>,
    #[serde(rename = "type")]
    scaling_type: Phi4MMScaledRopeType,
}

impl Phi4MMRotaryEmbedding {
    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_longrope(
        short_factor: &[f64],
        long_factor: &[f64],
        cfg: &Phi4MMConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
        };

        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;

        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;

        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Phi4MMRopeScalingConfig {
                scaling_type: Phi4MMScaledRopeType::LongRope,
                short_factor: Some(short_factor),
                long_factor: Some(long_factor),
            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),

            _ => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);

        let rot_dim = cos.dim(D::Minus1)? * 2;
        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
            (q_embed, k_embed)
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(
                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope(
                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            let q_rot = Tensor::cat(&q_embeds, 0)?;
            let k_rot = Tensor::cat(&k_embeds, 0)?;
            (q_rot, k_rot)
        };

        Ok((
            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
        ))
    }
}

#[derive(Debug, Clone)]
pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3nScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3nRopeScalingConfig {
    factor: f64,
    rope_type: Gemma3nScaledRopeType,
}

impl Gemma3nRotaryEmbedding {
    fn new_linear(
        cfg: &Gemma3nTextConfig,
        factor: f64,
        is_gpt_neox: bool,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let inv_freq = (inv_freq / factor)?;

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self(RotaryEmbedding {
            cos,
            sin,
            is_gpt_neox,
        }))
    }

    pub fn new(
        is_gpt_neox: bool,
        dtype: DType,
        cfg: &Gemma3nTextConfig,
        dev: &Device,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Gemma3RopeScalingConfig {
                rope_type: Gemma3ScaledRopeType::Linear,
                factor,
            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),

            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
        }
    }

    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
        self.0.get_cos_sin()
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

#[derive(Debug, Clone)]
pub struct Gemma3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3ScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3RopeScalingConfig {
    factor: f64,
    rope_type: Gemma3ScaledRopeType,
}

impl Gemma3RotaryEmbedding {
    fn new_linear(
        cfg: &Gemma3TextConfig,
        factor: f64,
        is_gpt_neox: bool,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let inv_freq = (inv_freq / factor)?;

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self(RotaryEmbedding {
            cos,
            sin,
            is_gpt_neox,
        }))
    }

    pub fn new(
        is_gpt_neox: bool,
        dtype: DType,
        cfg: &Gemma3TextConfig,
        dev: &Device,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Gemma3RopeScalingConfig {
                rope_type: Gemma3ScaledRopeType::Linear,
                factor,
            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),

            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
        }
    }

    fn new_linear_embedding_gemma(
        cfg: &EmbeddingGemmaConfig,
        factor: f64,
        is_gpt_neox: bool,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let inv_freq = (inv_freq / factor)?;

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self(RotaryEmbedding {
            cos,
            sin,
            is_gpt_neox,
        }))
    }

    pub fn new_embedding_gemma(
        is_gpt_neox: bool,
        dtype: DType,
        cfg: &EmbeddingGemmaConfig,
        dev: &Device,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Gemma3RopeScalingConfig {
                rope_type: Gemma3ScaledRopeType::Linear,
                factor,
            }) => Self::new_linear_embedding_gemma(cfg, *factor, is_gpt_neox, dtype, dev),

            _ => Self::new_linear_embedding_gemma(cfg, 1.0, is_gpt_neox, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

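/// Rotary embedding built from a geometric range of timescales between
/// `min_timescale` and `max_timescale` (used by the Dia model); `forward` rotates the
/// two halves of the last dimension by the cos/sin of `positions / timescale`.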
pub struct DiaRotaryEmbedding {
    timescale: Tensor,
    dtype: DType,
}

impl DiaRotaryEmbedding {
    pub fn new(
        min_timescale: f32,
        max_timescale: f32,
        head_dim: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        assert_eq!(head_dim % 2, 0);
        let half_embedding_dim = head_dim / 2;

        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
        let timescale = fraction
            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
            .collect::<Vec<_>>();

        let timescale_len = timescale.len();
        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;

        Ok(Self { timescale, dtype })
    }

    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
        let freqs = positions
            .unsqueeze(D::Minus1)?
            .unsqueeze(D::Minus1)?
            .broadcast_div(&self.timescale)?;

        let sin = freqs.sin()?.to_dtype(self.dtype)?;
        let cos = freqs.cos()?.to_dtype(self.dtype)?;

        let split = xs.chunk(2, D::Minus1)?;
        let first_half = &split[0];
        let second_half = &split[1];

        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;

        Tensor::cat(&[first_part, second_part], D::Minus1)
    }
}
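
/// Linear layer backed by a (possibly quantized) `QMatMul`, typically built from GGUF
/// tensors, with an optional f32 bias. Quantized inputs are computed in f32 and cast
/// back to the layer's `dtype`.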
#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    bias: Option<Tensor>,
    dtype: DType,
}

impl QLinear {
    pub fn new<R: std::io::Read + std::io::Seek>(
        ct: &mut Content<'_, R>,
        name: &str,
        device: &Device,
    ) -> Result<Self> {
        let w = ct.tensor(&format!("{name}.weight"), device)?;
        let b = ct.tensor(&format!("{name}.bias"), device)?;
        let inner = QMatMul::from_qtensor(w)?;
        let bias = b.dequantize(device)?;
        Ok(Self {
            inner,
            bias: Some(bias),
            dtype: DType::F32,
        })
    }

    pub fn from_linear(linear: Linear) -> Self {
        Self {
            inner: QMatMul::Tensor(linear.weight().clone()),
            bias: linear.bias().cloned(),
            dtype: linear.weight().dtype(),
        }
    }

    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
        let dtype = w.dtype();
        Self {
            inner: QMatMul::Tensor(w),
            bias: b,
            dtype,
        }
    }

    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
        if let Some(ref b) = b {
            assert_eq!(b.dtype(), DType::F32);
        }
        Self {
            inner: QMatMul::QTensor(Arc::new(w)),
            bias: b,
            dtype: DType::F32,
        }
    }

    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
        Self {
            inner,
            bias: old.bias.clone(),
            dtype: old.dtype,
        }
    }

    pub fn inner(&mut self) -> &mut QMatMul {
        &mut self.inner
    }

    pub fn inner_ref(&self) -> &QMatMul {
        &self.inner
    }

    pub fn is_quant(&self) -> bool {
        matches!(self.inner, QMatMul::QTensor(_))
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
        self.bias.as_mut()
    }
}

impl Module for QLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = if self.is_quant() {
            xs.to_dtype(DType::F32)?
        } else {
            xs.clone()
        };
        if let Some(bias) = &self.bias {
            self.inner
                .forward(&xs)?
                .broadcast_add(bias)?
                .to_dtype(self.dtype)
        } else {
            self.inner.forward(&xs)?.to_dtype(self.dtype)
        }
    }
}

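/// Standard rotary position embedding (RoPE) with precomputed `cos`/`sin` tables.
/// `is_gpt_neox` selects the GPT-NeoX ("half-rotated") layout; otherwise the
/// interleaved layout (`rope_i`) is applied.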
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    is_gpt_neox: bool,
}

impl RotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
        Ok((self.cos.clone(), self.sin.clone()))
    }

    pub fn new_partial(
        base: f32,
        rot_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..rot_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

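    /// Apply rotary position embeddings to `q` and `k`, both shaped
    /// `(batch, num_heads, seq_len, head_dim)`. `seqlen_offsets` holds one starting
    /// position per sequence (e.g. the current KV-cache length) and indexes into the
    /// precomputed tables.
    ///
    /// A minimal usage sketch (hypothetical dimensions):
    /// ```ignore
    /// let rope = RotaryEmbedding::new(10_000.0, 128, 4096, &Device::Cpu, true, DType::F32)?;
    /// let (q, k) = rope.forward(&q, &k, &[0])?; // q/k: (batch, heads, seq, head_dim)
    /// ```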
    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
        let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;

        let rope = if self.is_gpt_neox {
            candle_nn::rotary_emb::rope
        } else {
            candle_nn::rotary_emb::rope_i
        };

        if cfg!(feature = "cuda") && qh == kh {
            let (cos, sin) = if seqlen_offsets.len() == 1 {
                (
                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
                )
            } else {
                let mut cos_s = Vec::new();
                let mut sin_s = Vec::new();
                for offset in seqlen_offsets {
                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
                }
                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
            };

            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
            mistralrs_quant::rotary::apply_rotary_inplace(
                &q_embed,
                &k_embed,
                &cos,
                &sin,
                self.is_gpt_neox,
            )?;
            let mut q = q_embed
                .reshape((b_sz, seq_len, qh, n_embd))?
                .transpose(1, 2)?;
            let mut k = k_embed
                .reshape((b_sz, seq_len, kh, n_embd))?
                .transpose(1, 2)?;
            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
                q = q.contiguous()?;
                k = k.contiguous()?;
            }
            Ok((q, k))
        } else if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

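/// Rotary embedding for GPT-OSS models: YaRN-style ("NTK-by-parts") frequency scaling
/// with the attention temperature (`0.1 * ln(factor) + 1`) folded directly into the
/// precomputed `cos`/`sin` tables.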
#[derive(Debug, Clone)]
pub struct GptOssRotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    #[allow(dead_code)]
    attention_scale: f32,
}

impl GptOssRotaryEmbedding {
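    /// Build the scaled tables. `factor` is the context-extension ratio relative to
    /// `original_max_position_embeddings`; `beta_fast`/`beta_slow` bound the
    /// NTK-by-parts correction range, and `truncate` rounds that range to whole
    /// rotary dimensions.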
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        base: f64,
        head_dim: usize,
        max_position_embeddings: usize,
        factor: f64,
        original_max_position_embeddings: usize,
        beta_fast: f64,
        beta_slow: f64,
        truncate: bool,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        let dim = head_dim;

        let attention_scale = (0.1 * factor.ln() + 1.0) as f32;
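        // YaRN attention temperature ("mscale"); applied to the sin/cos tables below so
        // that attention logits are rescaled for the extended context.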

        let find_correction_dim = |num_rotations: f64| -> f64 {
            (dim as f64
                * (original_max_position_embeddings as f64
                    / (num_rotations * 2.0 * std::f64::consts::PI))
                .ln())
                / (2.0 * base.ln())
        };

        let mut low = find_correction_dim(beta_fast);
        let mut high = find_correction_dim(beta_slow);
        if truncate {
            low = low.floor();
            high = high.ceil();
        }
        low = low.max(0.0);
        high = high.min((dim - 1) as f64);

        let half_dim = dim / 2;
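        // Inverse frequencies per rotary dimension: the extrapolated set follows the base
        // theta schedule, the interpolated set divides it by the scaling factor.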
        let inv_freq_extrapolation: Vec<f64> = (0..dim)
            .step_by(2)
            .map(|i| 1.0 / base.powf(i as f64 / dim as f64))
            .collect();
        let inv_freq_interpolation: Vec<f64> =
            inv_freq_extrapolation.iter().map(|f| f / factor).collect();

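        // NTK-by-parts blend: a linear ramp between the low/high correction dims mixes
        // interpolated (long-wavelength) and extrapolated (short-wavelength) frequencies.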
        let inv_freq: Vec<f64> = (0..half_dim)
            .map(|i| {
                let range = if (high - low).abs() < 0.001 {
                    0.001
                } else {
                    high - low
                };
                let linear = (i as f64 - low) / range;
                let ramp = linear.clamp(0.0, 1.0);
                inv_freq_interpolation[i] * ramp + inv_freq_extrapolation[i] * (1.0 - ramp)
            })
            .collect();

        let inv_freq_len = inv_freq.len();
        let inv_freq_tensor = Tensor::from_vec(
            inv_freq.iter().map(|&x| x as f32).collect::<Vec<_>>(),
            (1, inv_freq_len),
            device,
        )?;

        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;

        let freqs = t.matmul(&inv_freq_tensor)?;

        let sin = (freqs.sin()? * attention_scale as f64)?.to_dtype(dtype)?;
        let cos = (freqs.cos()? * attention_scale as f64)?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            attention_scale,
        })
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        #[allow(unused_variables)]
        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
        #[allow(unused_variables)]
        let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;

        #[cfg(feature = "cuda")]
        if q.device().is_cuda() && qh == k.dim(1)? {
            let (cos, sin) = if seqlen_offsets.len() == 1 {
                (
                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
                )
            } else {
                let mut cos_s = Vec::new();
                let mut sin_s = Vec::new();
                for offset in seqlen_offsets {
                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
                }
                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
            };

            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;

            mistralrs_quant::rotary::apply_rotary_inplace(&q_embed, &k_embed, &cos, &sin, true)?;

            let mut q = q_embed
                .reshape((b_sz, seq_len, qh, n_embd))?
                .transpose(1, 2)?;
            let mut k = k_embed
                .reshape((b_sz, seq_len, kh, n_embd))?
                .transpose(1, 2)?;

            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
                q = q.contiguous()?;
                k = k.contiguous()?;
            }
            return Ok((q, k));
        }

        if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

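/// Activation functions as named in Hugging Face model configs (`hidden_act`),
/// deserialized from their lowercase config names (with aliases such as `gelu_new`
/// and `gelu_pytorch_tanh`).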
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
    Relu,
    Relu2,
    Relu6,
    Silu,
    Sigmoid,
    HardSigmoid,
    Swiglu,
    Swish,
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu_erf(),
            Self::NewGelu => xs.gelu(),
            Self::Relu => xs.relu(),
            Self::Relu2 => xs.relu()?.sqr(),
            Self::Relu6 => xs.clamp(0f32, 6f32),
            Self::Silu => xs.silu(),
            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
            Self::Swiglu => candle_nn::ops::swiglu(xs),
            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
            Self::GeluPytorchTanh => xs.gelu(),
            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?)?,
        }
    }
}

impl TryInto<candle_nn::Activation> for Activation {
    type Error = candle_core::Error;

    fn try_into(self) -> Result<candle_nn::Activation> {
        match self {
            Self::Gelu => Ok(candle_nn::Activation::Gelu),
            Self::Relu => Ok(candle_nn::Activation::Relu),
            Self::Silu => Ok(candle_nn::Activation::Silu),
            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
            Self::Swish => Ok(candle_nn::Activation::Swish),
            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    pub padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
}

impl Default for Conv3dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}

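/// Bias-free 3D convolution with a temporal kernel depth of 2, implemented as two 2D
/// convolutions (one per temporal slice) whose outputs are summed in `forward`.
/// Accepts either the conventional (out, in, d, h, w) weight layout or the MLX layout
/// (out, d, h, w, in), permuting the latter on load.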
pub struct Conv3dNoBias {
    conv2d_1: Conv2d,
    conv2d_2: Conv2d,
}

impl Conv3dNoBias {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_sizes: [usize; 3],
        cfg: Conv3dConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let expected_shape = (
            out_channels,
            in_channels / cfg.groups,
            kernel_sizes[0],
            kernel_sizes[1],
            kernel_sizes[2],
        );
        let mlx_shape = (
            out_channels,
            kernel_sizes[0],
            kernel_sizes[1],
            kernel_sizes[2],
            in_channels / cfg.groups,
        );
        let ws = if vb.contains_tensor("weight") {
            match vb.get(expected_shape, "weight") {
                Ok(ws) => ws,
                Err(_) => {
                    let ws = vb.get(mlx_shape, "weight")?;
                    ws.permute((0, 4, 1, 2, 3))?
                }
            }
        } else {
            vb.get(expected_shape, "weight")?
        };

        let w1 = ws.i((.., .., 0, .., ..))?;
        let w2 = ws.i((.., .., 1, .., ..))?;

        let cfg = Conv2dConfig {
            padding: cfg.padding,
            stride: cfg.stride,
            dilation: cfg.dilation,
            groups: cfg.groups,
            cudnn_fwd_algo: None,
        };

        Ok(Self {
            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
        })
    }

    pub fn weight(&self) -> Result<Tensor> {
        let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?;
        let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?;
        Tensor::cat(&[w1, w2], 2)
    }
}

impl Module for Conv3dNoBias {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs1 = xs.i((.., .., 0, .., ..))?;
        let xs2 = xs.i((.., .., 1, .., ..))?;

        (Convolution.forward_2d(&self.conv2d_1, &xs1)?
            + Convolution.forward_2d(&self.conv2d_2, &xs2)?)?
        .unsqueeze(2)
    }
}

pub trait TensorInfExtend {
    fn is_inf(&self) -> Result<Self>
    where
        Self: Sized;
    fn any(&self) -> Result<bool>;
}

impl TensorInfExtend for Tensor {
    fn is_inf(&self) -> Result<Self> {
        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
    }

    fn any(&self) -> Result<bool> {
        let sum = self.sum_all()?;
        match self.dtype() {
            DType::U8 => Ok(sum.to_scalar::<u8>()? == 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? == 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? == 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? == 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? == 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? == half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
            DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
                candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with .any")
            }
        }
    }
}

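/// Clamp a tensor to just inside its dtype's representable range so that a later
/// downcast (e.g. to f16) does not overflow.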
pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
    let mut max = match xs.dtype() {
        DType::U8 => u8::MAX as f32 - 1000.,
        DType::U32 => u32::MAX as f32 - 1000.,
        DType::I16 => i16::MAX as f32 - 1000.,
        DType::I32 => i32::MAX as f32 - 1000.,
        DType::I64 => i64::MAX as f32 - 1000.,
        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
        DType::F32 => f32::MAX - 1000.,
        DType::F64 => f64::MAX as f32 - 1000.,
        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
        DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
            candle_core::bail!(
                "f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with `clamp_for_f16`"
            )
        }
    };
    if xs.is_inf()?.any()? {
        max -= 1000.;
    }
    xs.clamp(-max, max)
}

pub struct FloatInfo {
    pub min: f64,
    pub max: f64,
    pub eps: f64,
    pub dtype: DType,
}

pub trait GetFloatInfo {
    fn finfo(&self) -> Result<FloatInfo>;
}

impl GetFloatInfo for DType {
    fn finfo(&self) -> Result<FloatInfo> {
        let finfo = match self {
            Self::BF16 => FloatInfo {
                min: bf16::MIN.to_f64(),
                max: bf16::MAX.to_f64(),
                eps: bf16::EPSILON.to_f64(),
                dtype: DType::BF16,
            },
            Self::F16 => FloatInfo {
                min: f16::MIN.to_f64(),
                max: f16::MAX.to_f64(),
                eps: f16::EPSILON.to_f64(),
                dtype: DType::F16,
            },
            Self::F32 => FloatInfo {
                min: f32::MIN as f64,
                max: f32::MAX as f64,
                eps: f32::EPSILON as f64,
                dtype: DType::F32,
            },
            Self::F64 => FloatInfo {
                min: f64::MIN,
                max: f64::MAX,
                eps: f64::EPSILON,
                dtype: DType::F64,
            },
            Self::F8E4M3 => FloatInfo {
                min: F8E4M3::MIN.to_f64(),
                max: F8E4M3::MAX.to_f64(),
                eps: F8E4M3::EPSILON.to_f64(),
                dtype: DType::F8E4M3,
            },
            other => {
                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
            }
        };
        Ok(finfo)
    }
}

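/// Gated MLP block: `down(act(gate(x)) * up(x))`, with the gate/up projections
/// column-parallel and the down projection row-parallel across `comm` for tensor
/// parallelism.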
#[derive(Clone)]
pub struct Mlp {
    pub gate: Arc<dyn QuantMethod>,
    pub up: Arc<dyn QuantMethod>,
    pub down: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl Mlp {
    pub fn new(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn new_merged(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        chunks: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
        let gate_up_projs = ColumnParallelLayer::new_merged(
            hidden_size,
            intermediate_size * 2,
            2,
            quantization_config,
            false,
            comm,
            vb.pp("gate_up_proj"),
        )?;

        Ok(Self {
            gate: gate_up_projs[0].to_owned(),
            up: gate_up_projs[1].to_owned(),
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn replicate(
        params: &[usize],
        vb: ShardedVarBuilder,
        act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Self::new(vb, params[0], params[1], &None, act, comm)
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        let mut res = self
            .down
            .forward(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
        let mut res =
            MatMul.qmethod_matmul(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?, &*self.down)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate, &mut self.up, &mut self.down]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let gate = if let Some(ref delta) = deltas[0] {
            self.gate.add_delta_w(delta)?
        } else {
            self.gate.clone()
        };
        let up = if let Some(ref delta) = deltas[1] {
            self.up.add_delta_w(delta)?
        } else {
            self.up.clone()
        };
        let down = if let Some(ref delta) = deltas[2] {
            self.down.add_delta_w(delta)?
        } else {
            self.down.clone()
        };

        Ok(Box::new(Self {
            gate,
            up,
            down,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate.dtype_and_device()
    }
}

pub struct AvgPool2d {
    kernel_size: usize,
    stride: usize,
}

impl AvgPool2d {
    pub fn new(kernel_size: usize, stride: usize) -> Self {
        Self {
            kernel_size,
            stride,
        }
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
    }
}

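/// Reflection padding over the last two (H, W) dimensions of an NCHW tensor.
/// `padding` is `(left, right, top, bottom)`; interior rows/columns are mirrored
/// without repeating the edge, matching PyTorch's `ReflectionPad2d` semantics.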
pub struct ReflectionPad2d {
    padding: (usize, usize, usize, usize),
}

impl ReflectionPad2d {
    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
        Self { padding }
    }
}

impl Module for ReflectionPad2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;

        let (_n, _c, h, w) = xs.dims4()?;

        let left_pad = if pad_left > 0 {
            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
        } else {
            None
        };

        let right_pad = if pad_right > 0 {
            let start = w as i64 - 2;
            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
        } else {
            None
        };

        let x_padded_width = match (left_pad, right_pad) {
            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
            (None, None) => xs.clone(),
        };

        let top_pad = if pad_top > 0 {
            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
        } else {
            None
        };

        let bottom_pad = if pad_bottom > 0 {
            let start = h as i64 - 2;
            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
        } else {
            None
        };

        let x_padded = match (top_pad, bottom_pad) {
            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
            (None, None) => x_padded_width,
        };

        Ok(x_padded)
    }
}

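/// Token embedding whose output is multiplied by a constant `scale`
/// (e.g. the sqrt(hidden_size) scaling used by Gemma-style models).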
pub struct ScaledEmbedding {
    scale: f64,
    pub embedding: Tensor,
}

impl ScaledEmbedding {
    pub fn new(scale: f64, embedding: Embedding) -> Self {
        Self {
            scale,
            embedding: embedding.embeddings().clone(),
        }
    }

    pub fn embeddings(&self) -> &Tensor {
        &self.embedding
    }
}

impl Module for ScaledEmbedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
        xs.apply(&embedding)? * self.scale
    }
}