1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6 quantized::{QMatMul, QTensor},
7 Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10 BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
11 LayerNorm, LayerNormConfig, Linear, Module,
12};
13use float8::F8E4M3;
14use half::{bf16, f16};
15use mistralrs_quant::{
16 AfqLayer, ColumnParallelLayer, Convolution, QuantMethod, QuantizedConfig, RowParallelLayer,
17 ShardedVarBuilder,
18};
19use serde::{Deserialize, Serialize};
20
21pub use crate::attention::Sdpa;
22pub use crate::layers_masker::CausalMasker;
23pub use crate::layers_utils::repeat_kv;
24use crate::{
25 amoe::{AnyMoeTrainableLayer, MlpLayer},
26 embedding_models::embedding_gemma::EmbeddingGemmaConfig,
27 gguf::Content,
28 models::{llama, smollm3},
29 ops::SplitOp,
30 vision_models::{
31 gemma3::config::Gemma3TextConfig,
32 gemma3n::config::Gemma3nTextConfig,
33 llama4,
34 mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
35 phi4::Phi4MMConfig,
36 },
37};
38
39pub use mistralrs_quant::MatMul;
40
41pub fn embedding(
42 in_size: usize,
43 out_size: usize,
44 vb: ShardedVarBuilder,
45 config: &Option<QuantizedConfig>,
46) -> Result<Embedding> {
47 let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
49 let afq_layer =
50 AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
51 afq_layer.dequantize_w()?
52 } else {
53 vb.get_with_hints((in_size, out_size), "weight", Default::default())?
54 };
55 Ok(Embedding::new(embeddings, out_size))
56}
57
58pub fn layer_norm<C: Into<LayerNormConfig>>(
59 size: usize,
60 config: C,
61 vb: ShardedVarBuilder,
62) -> Result<LayerNorm> {
63 let config = config.into();
64 let weight = vb.get(size, "weight")?;
65 if config.affine {
66 let bias = vb.get(size, "bias")?;
67 Ok(LayerNorm::new(weight, bias, config.eps))
68 } else {
69 Ok(LayerNorm::new_no_bias(weight, config.eps))
70 }
71}
72
73pub fn batch_norm<C: Into<BatchNormConfig>>(
74 num_features: usize,
75 config: C,
76 vb: ShardedVarBuilder,
77) -> Result<BatchNorm> {
78 let config = config.into();
79 if config.eps < 0. {
80 candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
81 }
82 let running_mean = vb.get(num_features, "running_mean")?;
83 let running_var = vb.get(num_features, "running_var")?;
84
85 if config.affine {
86 let weight = vb.get(num_features, "weight")?;
87 let bias = vb.get(num_features, "bias")?;
88 BatchNorm::new(
89 num_features,
90 running_mean,
91 running_var,
92 weight,
93 bias,
94 config.eps,
95 )
96 } else {
97 BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
98 }
99}
100
101pub fn group_norm(
102 num_groups: usize,
103 num_channels: usize,
104 eps: f64,
105 vb: ShardedVarBuilder,
106) -> Result<GroupNorm> {
107 let weight = vb.get(num_channels, "weight")?;
108 let bias = vb.get(num_channels, "bias")?;
109 GroupNorm::new(weight, bias, num_channels, num_groups, eps)
110}
111
112pub fn conv2d(
113 in_channels: usize,
114 out_channels: usize,
115 kernel_size: usize,
116 cfg: Conv2dConfig,
117 vb: ShardedVarBuilder,
118) -> Result<Conv2d> {
119 let ws = vb.get(
120 (
121 out_channels,
122 in_channels / cfg.groups,
123 kernel_size,
124 kernel_size,
125 ),
126 "weight",
127 )?;
128 let bs = vb.get(out_channels, "bias")?;
129 Ok(Conv2d::new(ws, Some(bs), cfg))
130}
131
132pub fn conv2d_no_bias(
133 in_channels: usize,
134 out_channels: usize,
135 kernel_size: usize,
136 cfg: Conv2dConfig,
137 vb: ShardedVarBuilder,
138) -> Result<Conv2d> {
139 let ws = vb.get(
140 (
141 out_channels,
142 in_channels / cfg.groups,
143 kernel_size,
144 kernel_size,
145 ),
146 "weight",
147 )?;
148 Ok(Conv2d::new(ws, None, cfg))
149}
150
151pub fn conv1d(
152 in_channels: usize,
153 out_channels: usize,
154 kernel_size: usize,
155 cfg: Conv1dConfig,
156 vb: ShardedVarBuilder,
157) -> Result<Conv1d> {
158 let ws = vb.get(
159 (out_channels, in_channels / cfg.groups, kernel_size),
160 "weight",
161 )?;
162 let bs = vb.get(out_channels, "bias")?;
163 Ok(Conv1d::new(ws, Some(bs), cfg))
164}
165
166pub fn conv1d_no_bias(
167 in_channels: usize,
168 out_channels: usize,
169 kernel_size: usize,
170 cfg: Conv1dConfig,
171 vb: ShardedVarBuilder,
172) -> Result<Conv1d> {
173 let ws = vb.get(
174 (out_channels, in_channels / cfg.groups, kernel_size),
175 "weight",
176 )?;
177 Ok(Conv1d::new(ws, None, cfg))
178}
179
180pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
181 let ws = vb.get((out_dim, in_dim), "weight")?;
182 let bs = vb.get(out_dim, "bias")?;
183 Ok(Linear::new(ws, Some(bs)))
184}
185
186pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
187 let ws = vb.get((out_dim, in_dim), "weight")?;
188 Ok(Linear::new(ws, None))
189}
190
191pub fn linear_b(
192 in_dim: usize,
193 out_dim: usize,
194 bias: bool,
195 vb: ShardedVarBuilder,
196) -> Result<Linear> {
197 if bias {
198 linear(in_dim, out_dim, vb)
199 } else {
200 linear_no_bias(in_dim, out_dim, vb)
201 }
202}
203
204#[derive(Debug, Clone)]
205pub struct RmsNorm {
206 eps: f64,
207 weight: Tensor,
208}
209
210impl RmsNorm {
211 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
212 let w = vb.get(size, "weight")?;
213 Ok(Self { eps, weight: w })
214 }
215
216 pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
218 let w = vb.get(size, "weight")?;
219 let w = (w + 1.0)?;
220 Ok(Self { eps, weight: w })
221 }
222
223 pub fn new_gemma_3n(
225 size: usize,
226 eps: f64,
227 with_scale: bool,
228 vb: ShardedVarBuilder,
229 ) -> Result<Self> {
230 let w = if with_scale {
231 vb.get(size, "weight")?
232 } else {
233 Tensor::ones(size, vb.dtype(), vb.device())?
234 };
235 Ok(Self { eps, weight: w })
236 }
237
238 pub fn undo_gemma(&self) -> Result<Self> {
240 Ok(Self {
241 eps: self.eps,
242 weight: (&self.weight - 1.0)?,
243 })
244 }
245
246 pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
247 Ok(Self { eps, weight: w })
248 }
249
250 pub fn weight(&self) -> &Tensor {
251 &self.weight
252 }
253}
254
255impl Module for RmsNorm {
256 fn forward(&self, x: &Tensor) -> Result<Tensor> {
257 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
258 }
259}
260
261#[derive(Debug, Clone)]
262pub struct F32RmsNorm {
263 w: Tensor,
264 eps: f64,
265}
266
267impl F32RmsNorm {
268 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
269 Ok(Self {
270 w: vb.get((size,), "weight")?,
271 eps,
272 })
273 }
274
275 pub fn weight(&self) -> &Tensor {
276 &self.w
277 }
278}
279
280impl Module for F32RmsNorm {
281 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
282 let initial_type = xs.dtype();
283 let mut xs = xs.to_dtype(DType::F32)?;
284 let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
285 xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
286 xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
287 }
288}
289
290#[derive(Debug, Clone)]
291pub struct QRmsNorm {
292 eps: f64,
293 weight: Tensor,
294}
295
296impl QRmsNorm {
297 pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
298 let scale = scale.dequantize(&scale.device())?;
299 Ok(Self {
300 eps: eps as f64,
301 weight: scale,
302 })
303 }
304
305 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
306 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
307 }
308}
309
310#[derive(Debug, Clone)]
312pub struct PhiRotaryEmbedding {
313 short_sin: Tensor,
314 short_cos: Tensor,
315 long_cos: Option<Tensor>,
316 long_sin: Option<Tensor>,
317 original_max_position_embeddings: usize,
318}
319
320#[derive(Debug, Clone, Deserialize, Serialize)]
321#[serde(rename_all = "lowercase")]
322pub enum ScaledRopeType {
323 #[serde(alias = "su")]
324 #[serde(alias = "longrope")]
325 Su,
326 #[serde(alias = "yarn")]
327 Yarn,
328 #[serde(alias = "dynamic")]
329 Dynamic,
330 #[serde(alias = "linear")]
331 Linear,
332}
333
334impl FromStr for ScaledRopeType {
335 type Err = candle_core::Error;
336 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
337 match s {
338 "su" | "longrope" => Ok(Self::Su),
339 "yarn" => Ok(Self::Yarn),
340 "linear" => Ok(Self::Linear),
341 "dynamic" => Ok(Self::Dynamic),
342 _ => Err(candle_core::Error::Msg(
343 "Expected either `su` or `yarn` scaled RoPE type.".to_string(),
344 )),
345 }
346 }
347}
348
349#[derive(Debug, Clone, Deserialize, Serialize)]
350#[serde(untagged)]
351pub enum PhiRopeScalingConfig {
352 Classic {
353 short_factor: Vec<f64>,
354 long_factor: Vec<f64>,
355 #[serde(rename = "type")]
356 scaling_type: ScaledRopeType,
357 },
358 Scaled {
359 short_factor: Vec<f64>,
360 long_factor: Vec<f64>,
361 #[serde(rename = "type")]
362 scaling_type: ScaledRopeType,
363 long_mscale: f64,
364 short_mscale: f64,
365 },
366}
367
368pub struct PhiRopeConfig {
369 pub rope_scaling: Option<PhiRopeScalingConfig>,
370 pub max_position_embeddings: usize,
371 pub original_max_position_embeddings: usize,
372 pub rope_theta: f64,
373 pub head_dim: usize,
374 pub partial_rotary_factor: Option<f64>,
375}
376
377impl PhiRotaryEmbedding {
378 fn new_classic_scaled(
379 short_factor: &[f64],
380 long_factor: &[f64],
381 scaling_type: &ScaledRopeType,
382 cfg: &PhiRopeConfig,
383 dtype: DType,
384 dev: &Device,
385 ) -> Result<Self> {
386 let max_seq_len = cfg.max_position_embeddings;
387 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
388
389 let scale =
391 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
392 let scaling_factor = if scale <= 1.0 {
393 1.0
394 } else {
395 match scaling_type {
396 ScaledRopeType::Su => {
397 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
398 }
399 ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
400 _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
401 }
402 };
403
404 let inv_freq_long = (0..dim)
406 .step_by(2)
407 .enumerate()
408 .map(|(k, i)| {
409 (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
410 })
411 .collect::<Vec<_>>();
412 let inv_freq_short = (0..dim)
413 .step_by(2)
414 .enumerate()
415 .map(|(k, i)| {
416 (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
417 })
418 .collect::<Vec<_>>();
419 let inv_freq_len = inv_freq_long.len();
420
421 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
422 .to_dtype(DType::F32)?
423 .reshape((max_seq_len, 1))?;
424
425 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
427 let freqs_long = t.matmul(&inv_freq_long)?;
428 let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
429 let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
430
431 let inv_freq_short =
433 Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
434 let freqs_short = t.matmul(&inv_freq_short)?;
435 let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
436 let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
437
438 Ok(Self {
439 short_cos,
440 short_sin,
441 long_cos: Some(long_cos),
442 long_sin: Some(long_sin),
443 original_max_position_embeddings: cfg.original_max_position_embeddings,
444 })
445 }
446
447 fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
448 let max_seq_len = cfg.max_position_embeddings;
449 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
450
451 let inv_freq: Vec<_> = (0..dim)
452 .step_by(2)
453 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
454 .collect();
455 let inv_freq_len = inv_freq.len();
456 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
457 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
458 .to_dtype(DType::F32)?
459 .reshape((max_seq_len, 1))?;
460 let freqs = t.matmul(&inv_freq)?;
461 let sin = freqs.sin()?.to_dtype(dtype)?;
462 let cos = freqs.cos()?.to_dtype(dtype)?;
463 Ok(Self {
464 short_cos: cos,
465 short_sin: sin,
466 long_cos: None,
467 long_sin: None,
468 original_max_position_embeddings: cfg.original_max_position_embeddings,
469 })
470 }
471
472 #[allow(clippy::too_many_arguments)]
473 fn new_scaled(
474 short_factor: &[f64],
475 long_factor: &[f64],
476 scaling_type: &ScaledRopeType,
477 long_mscale: f64,
478 short_mscale: f64,
479 cfg: &PhiRopeConfig,
480 dtype: DType,
481 dev: &Device,
482 ) -> Result<Self> {
483 let max_seq_len = cfg.max_position_embeddings;
484 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
485
486 if !matches!(scaling_type, ScaledRopeType::Su) {
487 candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
488 }
489
490 if short_factor.len() != dim / 2 {
491 candle_core::bail!(
492 "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
493 short_factor.len(),
494 dim / 2
495 );
496 }
497 if long_factor.len() != dim / 2 {
498 candle_core::bail!(
499 "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
500 long_factor.len(),
501 dim / 2
502 );
503 }
504
505 let inv_freq_short: Vec<_> = (0..dim)
507 .step_by(2)
508 .enumerate()
509 .map(|(k, i)| {
510 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
511 })
512 .collect();
513 let inv_freq_len_short = inv_freq_short.len();
514 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
515 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
516 .to_dtype(DType::F32)?
517 .reshape((max_seq_len, 1))?;
518 let freqs_short = t_short.matmul(&inv_freq_short)?;
519 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
520 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
521
522 let inv_freq_long: Vec<_> = (0..dim)
524 .step_by(2)
525 .enumerate()
526 .map(|(k, i)| {
527 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
528 })
529 .collect();
530 let inv_freq_len_long = inv_freq_long.len();
531 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
532 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
533 .to_dtype(DType::F32)?
534 .reshape((max_seq_len, 1))?;
535 let freqs_long = t_long.matmul(&inv_freq_long)?;
536 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
537 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
538 Ok(Self {
539 short_cos: cos_short,
540 short_sin: sin_short,
541 long_cos: Some(cos_long),
542 long_sin: Some(sin_long),
543 original_max_position_embeddings: cfg.original_max_position_embeddings,
544 })
545 }
546
547 pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
548 let cfg: PhiRopeConfig = cfg.into();
549
550 match &cfg.rope_scaling {
551 Some(PhiRopeScalingConfig::Classic {
552 short_factor,
553 long_factor,
554 scaling_type,
555 }) => {
556 Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
557 }
558
559 Some(PhiRopeScalingConfig::Scaled {
560 short_factor,
561 long_factor,
562 scaling_type,
563 long_mscale,
564 short_mscale,
565 }) => Self::new_scaled(
566 short_factor,
567 long_factor,
568 scaling_type,
569 *long_mscale,
570 *short_mscale,
571 &cfg,
572 dtype,
573 dev,
574 ),
575
576 None => Self::new_unscaled(&cfg, dtype, dev),
577 }
578 }
579
580 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
582 if self.long_cos.is_none() {
583 return (&self.short_sin, &self.short_cos);
584 }
585 let seq_len = position_ids.iter().max().unwrap() + 1;
586 if seq_len > self.original_max_position_embeddings {
587 (
588 self.long_sin.as_ref().unwrap(),
589 self.long_cos.as_ref().unwrap(),
590 )
591 } else {
592 (&self.short_sin, &self.short_cos)
593 }
594 }
595
596 pub fn forward(
597 &self,
598 q: &Tensor,
599 k: &Tensor,
600 seqlen_offsets: &[usize],
601 position_ids: &[usize],
602 ) -> Result<(Tensor, Tensor)> {
603 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
604 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
605
606 let rot_dim = cos.dim(D::Minus1)? * 2;
607
608 if rot_dim != q.dim(D::Minus1)? {
610 let rot_dim = cos.dim(D::Minus1)? * 2;
611 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
612 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
613 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
614 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
615
616 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
617 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
618 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
619 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
620 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
621 (q_embed, k_embed)
622 } else {
623 let mut q_embeds = Vec::new();
624 let mut k_embeds = Vec::new();
625 for (i, offset) in seqlen_offsets.iter().enumerate() {
626 let cos = cos.narrow(0, *offset, seq_len)?;
627 let sin = sin.narrow(0, *offset, seq_len)?;
628 let q_embed = candle_nn::rotary_emb::rope(
629 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
630 &cos,
631 &sin,
632 )?;
633 let k_embed = candle_nn::rotary_emb::rope(
634 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
635 &cos,
636 &sin,
637 )?;
638 q_embeds.push(q_embed);
639 k_embeds.push(k_embed);
640 }
641 let q_rot = Tensor::cat(&q_embeds, 0)?;
642 let k_rot = Tensor::cat(&k_embeds, 0)?;
643 (q_rot, k_rot)
644 };
645
646 Ok((
647 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
648 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
649 ))
650 } else if seqlen_offsets.len() == 1 {
651 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
652 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
653 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
654 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
655 Ok((q_embed, k_embed))
656 } else {
657 let mut q_embeds = Vec::new();
658 let mut k_embeds = Vec::new();
659 for (i, offset) in seqlen_offsets.iter().enumerate() {
660 let cos = cos.narrow(0, *offset, seq_len)?;
661 let sin = sin.narrow(0, *offset, seq_len)?;
662 let q_embed =
663 candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
664 let k_embed =
665 candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
666 q_embeds.push(q_embed);
667 k_embeds.push(k_embed);
668 }
669 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
670 }
671 }
672}
673
674#[derive(Debug, Clone)]
676pub struct Llama3RotaryEmbedding(RotaryEmbedding);
677
678#[derive(Debug, Clone, Deserialize, Serialize, Default)]
679pub enum Llama3RopeType {
680 #[serde(rename = "llama3")]
681 Llama3,
682 #[serde(rename = "linear")]
683 Linear,
684 #[default]
685 #[serde(rename = "default")]
686 Default,
687}
688
689#[derive(Debug, Clone, Deserialize, Serialize, Default)]
690pub struct Llama3RopeConfig {
691 pub factor: f32,
692 pub low_freq_factor: Option<f32>,
693 pub high_freq_factor: Option<f32>,
694 pub original_max_position_embeddings: Option<usize>,
695 pub rope_type: Llama3RopeType,
696}
697
698fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
699 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
700 (0..head_dim)
701 .step_by(2)
702 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
703 .collect()
704}
705
706fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
707 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
708 (0..head_dim)
709 .step_by(2)
710 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
711 .collect()
712}
713
714impl Llama3RotaryEmbedding {
716 pub fn new_llama3(
717 dtype: DType,
718 cfg: &llama::Config,
719 dev: &Device,
720 is_gpt_neox: bool,
721 ) -> Result<Self> {
722 match &cfg.rope_scaling {
723 None
724 | Some(Llama3RopeConfig {
725 rope_type: Llama3RopeType::Default,
726 ..
727 }) => Ok(Self(RotaryEmbedding::new(
728 cfg.rope_theta,
729 cfg.hidden_size / cfg.num_attention_heads,
730 cfg.max_position_embeddings,
731 dev,
732 is_gpt_neox,
733 dtype,
734 )?)),
735 Some(Llama3RopeConfig {
736 rope_type: Llama3RopeType::Llama3,
737 factor,
738 low_freq_factor,
739 high_freq_factor,
740 original_max_position_embeddings,
741 }) => {
742 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
743 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
744 let original_max_position_embeddings = original_max_position_embeddings
745 .context("original_max_position_embeddings is required")?;
746
747 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
748 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
749
750 let inv_freq = calculate_default_inv_freq(cfg)
751 .into_iter()
752 .map(|freq| {
753 let wavelen = 2. * PI / freq;
754 if wavelen < high_freq_wavelen {
755 freq
756 } else if wavelen > low_freq_wavelen {
757 freq / *factor
758 } else {
759 let smooth = (original_max_position_embeddings as f32 / wavelen
760 - low_freq_factor)
761 / (high_freq_factor - low_freq_factor);
762 (1. - smooth) * freq / *factor + smooth * freq
763 }
764 })
765 .collect::<Vec<_>>();
766 let inv_freq_len = inv_freq.len();
767 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
768 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
769 .to_dtype(DType::F32)?
770 .reshape((cfg.max_position_embeddings, 1))?;
771 let freqs = t.matmul(&inv_freq)?;
772 let sin = freqs.sin()?.to_dtype(dtype)?;
773 let cos = freqs.cos()?.to_dtype(dtype)?;
774 Ok(Self(RotaryEmbedding {
775 sin,
776 cos,
777 is_gpt_neox,
778 }))
779 }
780 Some(Llama3RopeConfig {
781 rope_type: Llama3RopeType::Linear,
782 factor,
783 ..
784 }) => {
785 let inv_freq_vec = calculate_default_inv_freq(cfg)
786 .into_iter()
787 .map(|freq| freq / *factor)
788 .collect::<Vec<_>>();
789 let inv_freq_len = inv_freq_vec.len();
790 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
791 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
792 .to_dtype(DType::F32)?
793 .reshape((cfg.max_position_embeddings, 1))?;
794 let freqs = t.matmul(&inv_freq)?;
795 let sin = freqs.sin()?.to_dtype(dtype)?;
796 let cos = freqs.cos()?.to_dtype(dtype)?;
797 Ok(Self(RotaryEmbedding {
798 sin,
799 cos,
800 is_gpt_neox,
801 }))
802 }
803 }
804 }
805
806 pub fn new_llama4(
807 dtype: DType,
808 cfg: &llama4::TextConfig,
809 dev: &Device,
810 is_gpt_neox: bool,
811 ) -> Result<Self> {
812 match &cfg.rope_scaling {
813 None
814 | Some(Llama3RopeConfig {
815 rope_type: Llama3RopeType::Default,
816 ..
817 }) => Ok(Self(RotaryEmbedding::new(
818 cfg.rope_theta,
819 cfg.hidden_size / cfg.num_attention_heads,
820 cfg.max_position_embeddings,
821 dev,
822 is_gpt_neox,
823 dtype,
824 )?)),
825 Some(Llama3RopeConfig {
826 rope_type: Llama3RopeType::Llama3,
827 factor,
828 low_freq_factor,
829 high_freq_factor,
830 original_max_position_embeddings,
831 }) => {
832 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
833 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
834 let original_max_position_embeddings = original_max_position_embeddings
835 .context("original_max_position_embeddings is required")?;
836
837 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
838 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
839
840 let inv_freq = calculate_default_inv_freq_llama4(cfg)
841 .into_iter()
842 .map(|freq| {
843 let wavelen = 2. * PI / freq;
844 if wavelen < high_freq_wavelen {
845 freq
846 } else if wavelen > low_freq_wavelen {
847 freq / *factor
848 } else {
849 let smooth = (original_max_position_embeddings as f32 / wavelen
850 - low_freq_factor)
851 / (high_freq_factor - low_freq_factor);
852 (1. - smooth) * freq / *factor + smooth * freq
853 }
854 })
855 .collect::<Vec<_>>();
856 let inv_freq_len = inv_freq.len();
857 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
858 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
859 .to_dtype(DType::F32)?
860 .reshape((cfg.max_position_embeddings, 1))?;
861 let freqs = t.matmul(&inv_freq)?;
862 let sin = freqs.sin()?.to_dtype(dtype)?;
863 let cos = freqs.cos()?.to_dtype(dtype)?;
864 Ok(Self(RotaryEmbedding {
865 sin,
866 cos,
867 is_gpt_neox,
868 }))
869 }
870 Some(Llama3RopeConfig {
871 rope_type: Llama3RopeType::Linear,
872 factor,
873 ..
874 }) => {
875 let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
876 .into_iter()
877 .map(|freq| freq / *factor)
878 .collect::<Vec<_>>();
879 let inv_freq_len = inv_freq_vec.len();
880 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
881 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
882 .to_dtype(DType::F32)?
883 .reshape((cfg.max_position_embeddings, 1))?;
884 let freqs = t.matmul(&inv_freq)?;
885 let sin = freqs.sin()?.to_dtype(dtype)?;
886 let cos = freqs.cos()?.to_dtype(dtype)?;
887 Ok(Self(RotaryEmbedding {
888 sin,
889 cos,
890 is_gpt_neox,
891 }))
892 }
893 }
894 }
895
896 pub fn new_mllama3(
897 dtype: DType,
898 cfg: &MLlamaTextConfig,
899 dev: &Device,
900 is_gpt_neox: bool,
901 ) -> Result<Self> {
902 match &cfg.rope_scaling {
903 None
904 | Some(MLlamaRopeScaling {
905 rope_type: MLlamaRopeType::Default,
906 ..
907 }) => Ok(Self(RotaryEmbedding::new(
908 cfg.rope_theta,
909 cfg.hidden_size / cfg.num_attention_heads,
910 cfg.max_position_embeddings,
911 dev,
912 is_gpt_neox,
913 dtype,
914 )?)),
915 Some(MLlamaRopeScaling {
916 rope_type: MLlamaRopeType::Llama3,
917 original_max_position_embeddings,
918 factor,
919 attention_factor: _,
920 beta_fast: _,
921 beta_slow: _,
922 short_factor: _,
923 long_factor: _,
924 low_freq_factor,
925 high_freq_factor,
926 }) => {
927 let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
928 let low_freq_factor = low_freq_factor
929 .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
930 let high_freq_factor = high_freq_factor
931 .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
932
933 let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
934 let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
935
936 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
937
938 let inv_freq = (0..head_dim)
939 .step_by(2)
940 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
941 .map(|freq| {
942 let wavelen = 2. * PI / freq;
943 if wavelen < high_freq_wavelen {
944 freq
945 } else if wavelen > low_freq_wavelen {
946 freq / factor
947 } else {
948 let smooth = (*original_max_position_embeddings as f32 / wavelen
949 - low_freq_factor)
950 / (high_freq_factor - low_freq_factor);
951 (1. - smooth) * freq / factor + smooth * freq
952 }
953 })
954 .collect::<Vec<_>>();
955 let inv_freq_len = inv_freq.len();
956 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
957
958 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
959 .to_dtype(DType::F32)?
960 .reshape((cfg.max_position_embeddings, 1))?;
961 let freqs = t.matmul(&inv_freq)?;
962 let sin = freqs.sin()?.to_dtype(dtype)?;
963 let cos = freqs.cos()?.to_dtype(dtype)?;
964 Ok(Self(RotaryEmbedding {
965 sin,
966 cos,
967 is_gpt_neox,
968 }))
969 }
970 Some(MLlamaRopeScaling {
971 rope_type: other, ..
972 }) => {
973 candle_core::bail!(
974 "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
975 )
976 }
977 }
978 }
979
980 pub fn forward(
981 &self,
982 q: &Tensor,
983 k: &Tensor,
984 seqlen_offsets: &[usize],
985 ) -> Result<(Tensor, Tensor)> {
986 self.0.forward(q, k, seqlen_offsets)
987 }
988}
989
990#[derive(Debug, Clone)]
992pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);
993
994#[derive(Debug, Clone, Deserialize, Serialize, Default)]
995pub enum SmolLm3RopeType {
996 #[serde(rename = "llama3")]
997 Llama3,
998 #[serde(rename = "linear")]
999 Linear,
1000 #[default]
1001 #[serde(rename = "default")]
1002 Default,
1003}
1004
1005#[derive(Debug, Clone, Deserialize, Serialize, Default)]
1006pub struct SmolLm3RopeConfig {
1007 pub factor: f32,
1008 pub low_freq_factor: Option<f32>,
1009 pub high_freq_factor: Option<f32>,
1010 pub original_max_position_embeddings: Option<usize>,
1011 pub rope_type: SmolLm3RopeType,
1012}
1013
1014fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
1015 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1016 (0..head_dim)
1017 .step_by(2)
1018 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1019 .collect()
1020}
1021
1022impl SmolLm3RotaryEmbedding {
1023 pub fn new_llama3(
1024 dtype: DType,
1025 cfg: &smollm3::Config,
1026 dev: &Device,
1027 is_gpt_neox: bool,
1028 ) -> Result<Self> {
1029 match &cfg.rope_scaling {
1030 None
1031 | Some(SmolLm3RopeConfig {
1032 rope_type: SmolLm3RopeType::Default,
1033 ..
1034 }) => Ok(Self(RotaryEmbedding::new(
1035 cfg.rope_theta,
1036 cfg.hidden_size / cfg.num_attention_heads,
1037 cfg.max_position_embeddings,
1038 dev,
1039 is_gpt_neox,
1040 dtype,
1041 )?)),
1042 Some(SmolLm3RopeConfig {
1043 rope_type: SmolLm3RopeType::Llama3,
1044 factor,
1045 low_freq_factor,
1046 high_freq_factor,
1047 original_max_position_embeddings,
1048 }) => {
1049 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
1050 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
1051 let original_max_position_embeddings = original_max_position_embeddings
1052 .context("original_max_position_embeddings is required")?;
1053
1054 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
1055 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
1056
1057 let inv_freq = calculate_default_inv_freq_smollm3(cfg)
1058 .into_iter()
1059 .map(|freq| {
1060 let wavelen = 2. * PI / freq;
1061 if wavelen < high_freq_wavelen {
1062 freq
1063 } else if wavelen > low_freq_wavelen {
1064 freq / *factor
1065 } else {
1066 let smooth = (original_max_position_embeddings as f32 / wavelen
1067 - low_freq_factor)
1068 / (high_freq_factor - low_freq_factor);
1069 (1. - smooth) * freq / *factor + smooth * freq
1070 }
1071 })
1072 .collect::<Vec<_>>();
1073 let inv_freq_len = inv_freq.len();
1074 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1075 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1076 .to_dtype(DType::F32)?
1077 .reshape((cfg.max_position_embeddings, 1))?;
1078 let freqs = t.matmul(&inv_freq)?;
1079 let sin = freqs.sin()?.to_dtype(dtype)?;
1080 let cos = freqs.cos()?.to_dtype(dtype)?;
1081 Ok(Self(RotaryEmbedding {
1082 sin,
1083 cos,
1084 is_gpt_neox,
1085 }))
1086 }
1087 Some(SmolLm3RopeConfig {
1088 rope_type: SmolLm3RopeType::Linear,
1089 factor,
1090 ..
1091 }) => {
1092 let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
1093 .into_iter()
1094 .map(|freq| freq / *factor)
1095 .collect::<Vec<_>>();
1096 let inv_freq_len = inv_freq_vec.len();
1097 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1098 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1099 .to_dtype(DType::F32)?
1100 .reshape((cfg.max_position_embeddings, 1))?;
1101 let freqs = t.matmul(&inv_freq)?;
1102 let sin = freqs.sin()?.to_dtype(dtype)?;
1103 let cos = freqs.cos()?.to_dtype(dtype)?;
1104 Ok(Self(RotaryEmbedding {
1105 sin,
1106 cos,
1107 is_gpt_neox,
1108 }))
1109 }
1110 }
1111 }
1112
1113 pub fn forward(
1114 &self,
1115 q: &Tensor,
1116 k: &Tensor,
1117 seqlen_offsets: &[usize],
1118 ) -> Result<(Tensor, Tensor)> {
1119 self.0.forward(q, k, seqlen_offsets)
1120 }
1121}
1122
1123#[derive(Debug, Clone)]
1125pub struct Qwen2VLRotaryEmbedding {
1126 inv_freq: Tensor,
1127 mrope_section: Vec<usize>,
1128}
1129
1130impl Qwen2VLRotaryEmbedding {
1131 pub fn new(
1132 base: f32,
1133 head_dim: usize,
1134 device: &Device,
1135 mrope_section: Vec<usize>,
1136 ) -> Result<Self> {
1137 let inv_freq: Vec<_> = (0..head_dim)
1138 .step_by(2)
1139 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1140 .collect();
1141 let inv_freq_len = inv_freq.len();
1142 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1143 Ok(Self {
1144 inv_freq,
1145 mrope_section,
1146 })
1147 }
1148
1149 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1151 let inv_freq_expanded =
1152 self.inv_freq
1153 .reshape((1, 1, (), 1))?
1154 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1155 let position_ids_expanded = position_ids.unsqueeze(2)?;
1156 let freqs = inv_freq_expanded
1157 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1158 .transpose(2, 3)?;
1159 let cos = freqs.cos()?;
1160 let sin = freqs.sin()?;
1161
1162 let cos = Tensor::cat(
1163 &cos.split(&self.mrope_section, D::Minus1)?
1164 .into_iter()
1165 .enumerate()
1166 .map(|(i, m)| m.i(i % 3))
1167 .collect::<Result<Vec<_>>>()?,
1168 D::Minus1,
1169 )?
1170 .squeeze(0)?
1171 .to_dtype(dtype)?
1172 .contiguous()?;
1173 let sin = Tensor::cat(
1174 &sin.split(&self.mrope_section, D::Minus1)?
1175 .into_iter()
1176 .enumerate()
1177 .map(|(i, m)| m.i(i % 3))
1178 .collect::<Result<Vec<_>>>()?,
1179 D::Minus1,
1180 )?
1181 .squeeze(0)?
1182 .to_dtype(dtype)?
1183 .contiguous()?;
1184
1185 Ok((cos, sin))
1186 }
1187
1188 pub fn forward(
1190 &self,
1191 (cos, sin): &(Tensor, Tensor),
1192 q: &mut Tensor,
1193 k: &mut Tensor,
1194 ) -> Result<()> {
1195 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1196 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1197 Ok(())
1198 }
1199}
1200
1201#[derive(Debug, Clone)]
1202pub struct Qwen2_5VLRotaryEmbedding {
1203 inv_freq: Tensor,
1204 mrope_section: Vec<usize>,
1205}
1206
1207impl Qwen2_5VLRotaryEmbedding {
1208 pub fn new(
1209 base: f32,
1210 head_dim: usize,
1211 device: &Device,
1212 mrope_section: Vec<usize>,
1213 ) -> Result<Self> {
1214 let inv_freq: Vec<_> = (0..head_dim)
1215 .step_by(2)
1216 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1217 .collect();
1218 let inv_freq_len = inv_freq.len();
1219 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1220 Ok(Self {
1221 inv_freq,
1222 mrope_section,
1223 })
1224 }
1225
1226 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1228 let inv_freq_expanded =
1229 self.inv_freq
1230 .reshape((1, 1, (), 1))?
1231 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1232 let position_ids_expanded = position_ids.unsqueeze(2)?;
1233 let freqs = inv_freq_expanded
1234 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1235 .transpose(2, 3)?;
1236 let cos = freqs.cos()?;
1237 let sin = freqs.sin()?;
1238
1239 let cos = Tensor::cat(
1240 &cos.split(&self.mrope_section, D::Minus1)?
1241 .into_iter()
1242 .enumerate()
1243 .map(|(i, m)| m.i(i % 3))
1244 .collect::<Result<Vec<_>>>()?,
1245 D::Minus1,
1246 )?
1247 .squeeze(0)?
1248 .to_dtype(dtype)?
1249 .contiguous()?;
1250 let sin = Tensor::cat(
1251 &sin.split(&self.mrope_section, D::Minus1)?
1252 .into_iter()
1253 .enumerate()
1254 .map(|(i, m)| m.i(i % 3))
1255 .collect::<Result<Vec<_>>>()?,
1256 D::Minus1,
1257 )?
1258 .squeeze(0)?
1259 .to_dtype(dtype)?
1260 .contiguous()?;
1261
1262 Ok((cos, sin))
1263 }
1264
1265 pub fn forward(
1266 &self,
1267 (cos, sin): &(Tensor, Tensor),
1268 q: &mut Tensor,
1269 k: &mut Tensor,
1270 ) -> Result<()> {
1271 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1272 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1273 Ok(())
1274 }
1275}
1276
1277#[derive(Debug, Clone)]
1278pub struct Qwen3VLRotaryEmbedding {
1279 inv_freq: Tensor,
1280 mrope_section: Vec<usize>,
1281}
1282
1283impl Qwen3VLRotaryEmbedding {
1284 pub fn new(
1285 base: f32,
1286 head_dim: usize,
1287 device: &Device,
1288 mrope_section: Vec<usize>,
1289 ) -> Result<Self> {
1290 let inv_freq: Vec<_> = (0..head_dim)
1291 .step_by(2)
1292 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1293 .collect();
1294 let inv_freq_len = inv_freq.len();
1295 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1296 Ok(Self {
1297 inv_freq,
1298 mrope_section,
1299 })
1300 }
1301
1302 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1304 let inv_freq_expanded =
1305 self.inv_freq
1306 .reshape((1, 1, (), 1))?
1307 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1308 let position_ids_expanded = position_ids.unsqueeze(2)?;
1309 let freqs = inv_freq_expanded
1310 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1311 .transpose(2, 3)?;
1312 let cos = freqs.cos()?;
1313 let sin = freqs.sin()?;
1314
1315 let cos = Tensor::cat(
1316 &cos.split(&self.mrope_section, D::Minus1)?
1317 .into_iter()
1318 .enumerate()
1319 .map(|(i, m)| m.i(i % 3))
1320 .collect::<Result<Vec<_>>>()?,
1321 D::Minus1,
1322 )?
1323 .squeeze(0)?
1324 .to_dtype(dtype)?
1325 .contiguous()?;
1326 let sin = Tensor::cat(
1327 &sin.split(&self.mrope_section, D::Minus1)?
1328 .into_iter()
1329 .enumerate()
1330 .map(|(i, m)| m.i(i % 3))
1331 .collect::<Result<Vec<_>>>()?,
1332 D::Minus1,
1333 )?
1334 .squeeze(0)?
1335 .to_dtype(dtype)?
1336 .contiguous()?;
1337
1338 Ok((cos, sin))
1339 }
1340
1341 pub fn forward(
1342 &self,
1343 (cos, sin): &(Tensor, Tensor),
1344 q: &mut Tensor,
1345 k: &mut Tensor,
1346 ) -> Result<()> {
1347 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1348 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1349 Ok(())
1350 }
1351}
1352
1353#[derive(Debug, Clone)]
1354pub struct DeepSeekV2RotaryEmbedding {
1355 sin: Tensor,
1356 cos: Tensor,
1357}
1358
1359#[derive(Debug, Clone, Deserialize, Serialize)]
1360#[serde(untagged)]
1361pub enum DeepSeekV2RopeScaling {
1362 Yarn {
1363 original_max_position_embeddings: usize,
1364 beta_fast: f32,
1365 beta_slow: f32,
1366 mscale: f32,
1367 mscale_all_dim: f32,
1368 factor: f32,
1369 #[serde(rename = "type")]
1370 scaling_type: ScaledRopeType,
1371 },
1372 LinearOrDynamic {
1373 #[serde(rename = "type")]
1374 scaling_type: ScaledRopeType,
1375 factor: f64,
1376 },
1377}
1378
1379pub struct DeepSeekV2RopeConfig {
1380 pub rope_scaling: Option<DeepSeekV2RopeScaling>,
1381 pub max_position_embeddings: usize,
1382 pub rope_theta: f32,
1383 pub qk_rope_head_dim: usize,
1384}
1385
1386impl DeepSeekV2RotaryEmbedding {
1387 fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1388 let max_seq_len = cfg.max_position_embeddings;
1389 let dim = cfg.qk_rope_head_dim;
1390
1391 let inv_freq: Vec<_> = (0..dim)
1392 .step_by(2)
1393 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1394 .collect();
1395 let inv_freq_len = inv_freq.len();
1396 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1397 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1398 .to_dtype(DType::F32)?
1399 .reshape((max_seq_len, 1))?;
1400 let freqs = t.matmul(&inv_freq)?;
1401
1402 let sin = freqs.sin()?.to_dtype(dtype)?;
1403 let cos = freqs.cos()?.to_dtype(dtype)?;
1404
1405 Ok(Self { sin, cos })
1406 }
1407
1408 fn yarn_find_correction_dim(
1409 num_rot: f32,
1410 dim: usize,
1411 base: f32,
1412 max_position_embeddings: usize,
1413 ) -> f32 {
1414 (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1415 / (2. * base.ln())
1416 }
1417
1418 fn yarn_find_correction_range(
1419 low_rot: f32,
1420 high_rot: f32,
1421 dim: usize,
1422 base: f32,
1423 max_position_embeddings: usize,
1424 ) -> (f32, f32) {
1425 let low =
1426 Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1427 let high =
1428 Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1429 (low.max(0.), high.min(dim as f32 - 1.))
1430 }
1431
1432 fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1433 if min == max {
1434 max += 0.001;
1436 }
1437 let linear_func =
1438 ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1439 linear_func.clamp(0., 1)
1440 }
1441
1442 pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1443 if scale <= 1. {
1444 return 1.;
1445 }
1446 0.1 * mscale * scale.ln() + 1.
1447 }
1448
1449 #[allow(clippy::too_many_arguments)]
1450 fn new_yarn(
1451 cfg: &DeepSeekV2RopeConfig,
1452 dtype: DType,
1453 dev: &Device,
1454 original_max_position_embeddings: usize,
1455 beta_fast: f32,
1456 beta_slow: f32,
1457 factor: f32,
1458 mscale: f32,
1459 mscale_all_dim: f32,
1460 ) -> Result<Self> {
1461 let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1462 .step_by(2)
1463 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1464 .collect();
1465 let freq_extra_len = freq_extra.len();
1466 let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1467 let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1468 .step_by(2)
1469 .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1470 .collect();
1471 let freq_inter_len = freq_inter.len();
1472 let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1473
1474 let (low, high) = Self::yarn_find_correction_range(
1475 beta_fast,
1476 beta_slow,
1477 cfg.qk_rope_head_dim,
1478 cfg.rope_theta,
1479 original_max_position_embeddings,
1480 );
1481 let inv_freq_mask =
1482 (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1483 let inv_freq = freq_inter
1484 .broadcast_mul(&(1. - &inv_freq_mask)?)?
1485 .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1486
1487 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1488 .to_dtype(DType::F32)?
1489 .reshape((cfg.max_position_embeddings, 1))?;
1490 let freqs = t.matmul(&inv_freq)?;
1491
1492 let mscale =
1493 Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1494 let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1495 let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1496
1497 Ok(Self { sin, cos })
1498 }
1499
1500 pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1501 match &cfg.rope_scaling {
1502 Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1503 scaling_type: _,
1504 factor: _,
1505 }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1506 Some(DeepSeekV2RopeScaling::Yarn {
1507 original_max_position_embeddings,
1508 beta_fast,
1509 beta_slow,
1510 factor,
1511 mscale,
1512 mscale_all_dim,
1513 scaling_type: _,
1514 }) => Self::new_yarn(
1515 cfg,
1516 dtype,
1517 dev,
1518 *original_max_position_embeddings,
1519 *beta_fast,
1520 *beta_slow,
1521 *factor,
1522 *mscale,
1523 *mscale_all_dim,
1524 ),
1525 None => Self::new_unscaled(cfg, dtype, dev),
1526 }
1527 }
1528
1529 pub fn forward(
1530 &self,
1531 q: &Tensor,
1532 k: &Tensor,
1533 seqlen_offsets: &[usize],
1534 ) -> Result<(Tensor, Tensor)> {
1535 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1536
1537 if seqlen_offsets.len() == 1 {
1538 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1539 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1540 let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1541 let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1542 Ok((q_embed, k_embed))
1543 } else {
1544 let mut q_embeds = Vec::new();
1545 let mut k_embeds = Vec::new();
1546 for (i, offset) in seqlen_offsets.iter().enumerate() {
1547 let cos = self.cos.narrow(0, *offset, seq_len)?;
1548 let sin = self.sin.narrow(0, *offset, seq_len)?;
1549 let q_embed = candle_nn::rotary_emb::rope_i(
1550 &q.i(i)?.unsqueeze(0)?.contiguous()?,
1551 &cos,
1552 &sin,
1553 )?;
1554 let k_embed = candle_nn::rotary_emb::rope_i(
1555 &k.i(i)?.unsqueeze(0)?.contiguous()?,
1556 &cos,
1557 &sin,
1558 )?;
1559 q_embeds.push(q_embed);
1560 k_embeds.push(k_embed);
1561 }
1562 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1563 }
1564 }
1565}
1566
1567#[derive(Debug, Clone)]
1568pub struct Phi4MMRotaryEmbedding {
1569 short_sin: Tensor,
1570 short_cos: Tensor,
1571 long_cos: Option<Tensor>,
1572 long_sin: Option<Tensor>,
1573 original_max_position_embeddings: usize,
1574}
1575
1576#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1577#[serde(rename_all = "lowercase")]
1578pub enum Phi4MMScaledRopeType {
1579 #[serde(alias = "longrope")]
1580 LongRope,
1581 #[default]
1582 Default,
1583}
1584
1585#[derive(Debug, Clone, Deserialize, Serialize)]
1586pub struct Phi4MMRopeScalingConfig {
1587 short_factor: Option<Vec<f64>>,
1588 long_factor: Option<Vec<f64>>,
1589 #[serde(rename = "type")]
1590 scaling_type: Phi4MMScaledRopeType,
1591}
1592
1593impl Phi4MMRotaryEmbedding {
1594 fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1595 let max_seq_len = cfg.max_position_embeddings;
1596 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1597
1598 let inv_freq: Vec<_> = (0..dim)
1599 .step_by(2)
1600 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1601 .collect();
1602 let inv_freq_len = inv_freq.len();
1603 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1604 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1605 .to_dtype(DType::F32)?
1606 .reshape((max_seq_len, 1))?;
1607 let freqs = t.matmul(&inv_freq)?;
1608 let sin = freqs.sin()?.to_dtype(dtype)?;
1609 let cos = freqs.cos()?.to_dtype(dtype)?;
1610 Ok(Self {
1611 short_cos: cos,
1612 short_sin: sin,
1613 long_cos: None,
1614 long_sin: None,
1615 original_max_position_embeddings: cfg.original_max_position_embeddings,
1616 })
1617 }
1618
1619 #[allow(clippy::too_many_arguments)]
1620 fn new_longrope(
1621 short_factor: &[f64],
1622 long_factor: &[f64],
1623 cfg: &Phi4MMConfig,
1624 dtype: DType,
1625 dev: &Device,
1626 ) -> Result<Self> {
1627 let max_seq_len = cfg.max_position_embeddings;
1628 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1629
1630 let scale =
1632 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1633 let scaling_factor = if scale <= 1.0 {
1634 1.0
1635 } else {
1636 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1637 };
1638
1639 let inv_freq_short: Vec<_> = (0..dim)
1641 .step_by(2)
1642 .enumerate()
1643 .map(|(k, i)| {
1644 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1645 })
1646 .collect();
1647 let inv_freq_len_short = inv_freq_short.len();
1648 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1649 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1650 .to_dtype(DType::F32)?
1651 .reshape((max_seq_len, 1))?;
1652 let freqs_short = t_short.matmul(&inv_freq_short)?;
1653 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1654 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1655
1656 let inv_freq_long: Vec<_> = (0..dim)
1658 .step_by(2)
1659 .enumerate()
1660 .map(|(k, i)| {
1661 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1662 })
1663 .collect();
1664 let inv_freq_len_long = inv_freq_long.len();
1665 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1666 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1667 .to_dtype(DType::F32)?
1668 .reshape((max_seq_len, 1))?;
1669 let freqs_long = t_long.matmul(&inv_freq_long)?;
1670 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1671 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1672
1673 Ok(Self {
1674 short_cos: cos_short,
1675 short_sin: sin_short,
1676 long_cos: Some(cos_long),
1677 long_sin: Some(sin_long),
1678 original_max_position_embeddings: cfg.original_max_position_embeddings,
1679 })
1680 }
1681
1682 pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1683 match &cfg.rope_scaling {
1684 Some(Phi4MMRopeScalingConfig {
1685 scaling_type: Phi4MMScaledRopeType::LongRope,
1686 short_factor: Some(short_factor),
1687 long_factor: Some(long_factor),
1688 }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1689
1690 _ => Self::new_unscaled(cfg, dtype, dev),
1691 }
1692 }
1693
1694 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1696 if self.long_cos.is_none() {
1697 return (&self.short_sin, &self.short_cos);
1698 }
1699 let seq_len = position_ids.iter().max().unwrap() + 1;
1700 if seq_len > self.original_max_position_embeddings {
1701 (
1702 self.long_sin.as_ref().unwrap(),
1703 self.long_cos.as_ref().unwrap(),
1704 )
1705 } else {
1706 (&self.short_sin, &self.short_cos)
1707 }
1708 }
1709
1710 pub fn forward(
1711 &self,
1712 q: &Tensor,
1713 k: &Tensor,
1714 seqlen_offsets: &[usize],
1715 position_ids: &[usize],
1716 ) -> Result<(Tensor, Tensor)> {
1717 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1718 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1719
1720 let rot_dim = cos.dim(D::Minus1)? * 2;
1721 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1722 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1723 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1724 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1725
1726 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1727 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1728 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1729 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1730 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1731 (q_embed, k_embed)
1732 } else {
1733 let mut q_embeds = Vec::new();
1734 let mut k_embeds = Vec::new();
1735 for (i, offset) in seqlen_offsets.iter().enumerate() {
1736 let cos = cos.narrow(0, *offset, seq_len)?;
1737 let sin = sin.narrow(0, *offset, seq_len)?;
1738 let q_embed = candle_nn::rotary_emb::rope(
1739 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1740 &cos,
1741 &sin,
1742 )?;
1743 let k_embed = candle_nn::rotary_emb::rope(
1744 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1745 &cos,
1746 &sin,
1747 )?;
1748 q_embeds.push(q_embed);
1749 k_embeds.push(k_embed);
1750 }
1751 let q_rot = Tensor::cat(&q_embeds, 0)?;
1752 let k_rot = Tensor::cat(&k_embeds, 0)?;
1753 (q_rot, k_rot)
1754 };
1755
1756 Ok((
1757 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1758 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1759 ))
1760 }
1761}
1762
1763#[derive(Debug, Clone)]
1764pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);
1765
1766#[derive(Debug, Clone, Deserialize, Serialize)]
1767#[serde(rename_all = "lowercase")]
1768pub enum Gemma3nScaledRopeType {
1769 #[serde(alias = "linear")]
1770 Linear,
1771}
1772
1773#[derive(Debug, Clone, Deserialize, Serialize)]
1774pub struct Gemma3nRopeScalingConfig {
1775 factor: f64,
1776 rope_type: Gemma3nScaledRopeType,
1777}
1778
1779impl Gemma3nRotaryEmbedding {
1780 fn new_linear(
1781 cfg: &Gemma3nTextConfig,
1782 factor: f64,
1783 is_gpt_neox: bool,
1784 dtype: DType,
1785 dev: &Device,
1786 ) -> Result<Self> {
1787 let max_seq_len = cfg.max_position_embeddings;
1788 let dim = cfg.head_dim;
1789
1790 let inv_freq: Vec<_> = (0..dim)
1791 .step_by(2)
1792 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1793 .collect();
1794 let inv_freq_len = inv_freq.len();
1795 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1796 let inv_freq = (inv_freq / factor)?;
1797
1798 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1799 .to_dtype(DType::F32)?
1800 .reshape((max_seq_len, 1))?;
1801 let freqs = t.matmul(&inv_freq)?;
1802 let sin = freqs.sin()?.to_dtype(dtype)?;
1803 let cos = freqs.cos()?.to_dtype(dtype)?;
1804 Ok(Self(RotaryEmbedding {
1805 cos,
1806 sin,
1807 is_gpt_neox,
1808 }))
1809 }
1810
1811 pub fn new(
1812 is_gpt_neox: bool,
1813 dtype: DType,
1814 cfg: &Gemma3nTextConfig,
1815 dev: &Device,
1816 ) -> Result<Self> {
1817 match &cfg.rope_scaling {
1818 Some(Gemma3RopeScalingConfig {
1819 rope_type: Gemma3ScaledRopeType::Linear,
1820 factor,
1821 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1822
1823 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1824 }
1825 }
1826
1827 pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
1828 self.0.get_cos_sin()
1829 }
1830
1831 pub fn forward(
1832 &self,
1833 q: &Tensor,
1834 k: &Tensor,
1835 seqlen_offsets: &[usize],
1836 ) -> Result<(Tensor, Tensor)> {
1837 self.0.forward(q, k, seqlen_offsets)
1838 }
1839}
1840
1841#[derive(Debug, Clone)]
1842pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1843
1844#[derive(Debug, Clone, Deserialize, Serialize)]
1845#[serde(rename_all = "lowercase")]
1846pub enum Gemma3ScaledRopeType {
1847 #[serde(alias = "linear")]
1848 Linear,
1849}
1850
1851#[derive(Debug, Clone, Deserialize, Serialize)]
1852pub struct Gemma3RopeScalingConfig {
1853 factor: f64,
1854 rope_type: Gemma3ScaledRopeType,
1855}
1856
1857impl Gemma3RotaryEmbedding {
1858 fn new_linear(
1859 cfg: &Gemma3TextConfig,
1860 factor: f64,
1861 is_gpt_neox: bool,
1862 dtype: DType,
1863 dev: &Device,
1864 ) -> Result<Self> {
1865 let max_seq_len = cfg.max_position_embeddings;
1866 let dim = cfg.head_dim;
1867
1868 let inv_freq: Vec<_> = (0..dim)
1869 .step_by(2)
1870 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1871 .collect();
1872 let inv_freq_len = inv_freq.len();
1873 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1874 let inv_freq = (inv_freq / factor)?;
1875
1876 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1877 .to_dtype(DType::F32)?
1878 .reshape((max_seq_len, 1))?;
1879 let freqs = t.matmul(&inv_freq)?;
1880 let sin = freqs.sin()?.to_dtype(dtype)?;
1881 let cos = freqs.cos()?.to_dtype(dtype)?;
1882 Ok(Self(RotaryEmbedding {
1883 cos,
1884 sin,
1885 is_gpt_neox,
1886 }))
1887 }
1888
1889 pub fn new(
1890 is_gpt_neox: bool,
1891 dtype: DType,
1892 cfg: &Gemma3TextConfig,
1893 dev: &Device,
1894 ) -> Result<Self> {
1895 match &cfg.rope_scaling {
1896 Some(Gemma3RopeScalingConfig {
1897 rope_type: Gemma3ScaledRopeType::Linear,
1898 factor,
1899 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1900
1901 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1902 }
1903 }
1904
1905 fn new_linear_embedding_gemma(
1906 cfg: &EmbeddingGemmaConfig,
1907 factor: f64,
1908 is_gpt_neox: bool,
1909 dtype: DType,
1910 dev: &Device,
1911 ) -> Result<Self> {
1912 let max_seq_len = cfg.max_position_embeddings;
1913 let dim = cfg.head_dim;
1914
1915 let inv_freq: Vec<_> = (0..dim)
1916 .step_by(2)
1917 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1918 .collect();
1919 let inv_freq_len = inv_freq.len();
1920 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1921 let inv_freq = (inv_freq / factor)?;
1922
1923 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1924 .to_dtype(DType::F32)?
1925 .reshape((max_seq_len, 1))?;
1926 let freqs = t.matmul(&inv_freq)?;
1927 let sin = freqs.sin()?.to_dtype(dtype)?;
1928 let cos = freqs.cos()?.to_dtype(dtype)?;
1929 Ok(Self(RotaryEmbedding {
1930 cos,
1931 sin,
1932 is_gpt_neox,
1933 }))
1934 }
1935
1936 pub fn new_embedding_gemma(
1937 is_gpt_neox: bool,
1938 dtype: DType,
1939 cfg: &EmbeddingGemmaConfig,
1940 dev: &Device,
1941 ) -> Result<Self> {
1942 match &cfg.rope_scaling {
1943 Some(Gemma3RopeScalingConfig {
1944 rope_type: Gemma3ScaledRopeType::Linear,
1945 factor,
1946 }) => Self::new_linear_embedding_gemma(cfg, *factor, is_gpt_neox, dtype, dev),
1947
1948 _ => Self::new_linear_embedding_gemma(cfg, 1.0, is_gpt_neox, dtype, dev),
1949 }
1950 }
1951
1952 pub fn forward(
1953 &self,
1954 q: &Tensor,
1955 k: &Tensor,
1956 seqlen_offsets: &[usize],
1957 ) -> Result<(Tensor, Tensor)> {
1958 self.0.forward(q, k, seqlen_offsets)
1959 }
1960}
1961
pub struct DiaRotaryEmbedding {
    timescale: Tensor,
    dtype: DType,
}

impl DiaRotaryEmbedding {
    pub fn new(
        min_timescale: f32,
        max_timescale: f32,
        head_dim: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        assert_eq!(head_dim % 2, 0);
        let half_embedding_dim = head_dim / 2;

        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
        let timescale = fraction
            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
            .collect::<Vec<_>>();

        let timescale_len = timescale.len();
        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;

        Ok(Self { timescale, dtype })
    }

    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
        let freqs = positions
            .unsqueeze(D::Minus1)?
            .unsqueeze(D::Minus1)?
            .broadcast_div(&self.timescale)?;

        let sin = freqs.sin()?.to_dtype(self.dtype)?;
        let cos = freqs.cos()?.to_dtype(self.dtype)?;

        let split = xs.chunk(2, D::Minus1)?;
        let first_half = &split[0];
        let second_half = &split[1];

        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;

        Tensor::cat(&[first_part, second_part], D::Minus1)
    }
}
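
/// A linear layer whose weight may be a quantized [`QMatMul`] or a plain
/// tensor, with an optional (always dequantized) bias.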
#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    bias: Option<Tensor>,
    dtype: DType,
}

impl QLinear {
    pub fn new<R: std::io::Read + std::io::Seek>(
        ct: &mut Content<'_, R>,
        name: &str,
        device: &Device,
    ) -> Result<Self> {
        let w = ct.tensor(&format!("{name}.weight"), device)?;
        let b = ct.tensor(&format!("{name}.bias"), device)?;
        let inner = QMatMul::from_qtensor(w)?;
        let bias = b.dequantize(device)?;
        Ok(Self {
            inner,
            bias: Some(bias),
            dtype: DType::F32,
        })
    }

    pub fn from_linear(linear: Linear) -> Self {
        Self {
            inner: QMatMul::Tensor(linear.weight().clone()),
            bias: linear.bias().cloned(),
            dtype: linear.weight().dtype(),
        }
    }

    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
        let dtype = w.dtype();
        Self {
            inner: QMatMul::Tensor(w),
            bias: b,
            dtype,
        }
    }

    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
        if let Some(ref b) = b {
            assert_eq!(b.dtype(), DType::F32);
        }
        Self {
            inner: QMatMul::QTensor(Arc::new(w)),
            bias: b,
            dtype: DType::F32,
        }
    }

    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
        Self {
            inner,
            bias: old.bias.clone(),
            dtype: old.dtype,
        }
    }

    pub fn inner(&mut self) -> &mut QMatMul {
        &mut self.inner
    }

    pub fn inner_ref(&self) -> &QMatMul {
        &self.inner
    }

    pub fn is_quant(&self) -> bool {
        matches!(self.inner, QMatMul::QTensor(_))
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
        self.bias.as_mut()
    }
}

impl Module for QLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = if self.is_quant() {
            xs.to_dtype(DType::F32)?
        } else {
            xs.clone()
        };
        if let Some(bias) = &self.bias {
            self.inner
                .forward(&xs)?
                .broadcast_add(bias)?
                .to_dtype(self.dtype)
        } else {
            self.inner.forward(&xs)?.to_dtype(self.dtype)
        }
    }
}

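/// Precomputed rotary position embedding (RoPE) cache.
///
/// `cos`/`sin` are tabulated for `max_position_embeddings` positions; `forward`
/// applies either the GPT-NeoX (half-rotated) or the interleaved variant to
/// `q`/`k` of shape `(batch, heads, seq_len, head_dim)`.
///
/// A minimal usage sketch (base, dims, and device are illustrative placeholders):
///
/// ```ignore
/// let rope = RotaryEmbedding::new(10_000.0, 128, 4096, &Device::Cpu, true, DType::F32)?;
/// // q, k: (batch, n_heads, seq_len, head_dim); one offset per sequence in the batch.
/// let (q, k) = rope.forward(&q, &k, &[0])?;
/// ```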
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    is_gpt_neox: bool,
}

impl RotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
        Ok((self.cos.clone(), self.sin.clone()))
    }

    pub fn new_partial(
        base: f32,
        rot_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..rot_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
        let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;

        let rope = if self.is_gpt_neox {
            candle_nn::rotary_emb::rope
        } else {
            candle_nn::rotary_emb::rope_i
        };

        if cfg!(feature = "cuda") && qh == kh {
            let (cos, sin) = if seqlen_offsets.len() == 1 {
                (
                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
                )
            } else {
                let mut cos_s = Vec::new();
                let mut sin_s = Vec::new();
                for offset in seqlen_offsets {
                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
                }
                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
            };

            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
            mistralrs_quant::rotary::apply_rotary_inplace(
                &q_embed,
                &k_embed,
                &cos,
                &sin,
                self.is_gpt_neox,
            )?;
            let mut q = q_embed
                .reshape((b_sz, seq_len, qh, n_embd))?
                .transpose(1, 2)?;
            let mut k = k_embed
                .reshape((b_sz, seq_len, kh, n_embd))?
                .transpose(1, 2)?;
            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
                q = q.contiguous()?;
                k = k.contiguous()?;
            }
            Ok((q, k))
        } else if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

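/// Activation functions used across model configs, deserialized from the
/// HuggingFace-style `hidden_act` strings (e.g. `"silu"`, `"gelu_pytorch_tanh"`).
///
/// An illustrative round trip (the input string is a placeholder):
///
/// ```ignore
/// let act: Activation = serde_json::from_str("\"silu\"")?;
/// let ys = act.forward(&xs)?;
/// ```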
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
    Relu,
    Relu2,
    Relu6,
    Silu,
    Sigmoid,
    HardSigmoid,
    Swiglu,
    Swish,
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu_erf(),
            Self::NewGelu => xs.gelu(),
            Self::Relu => xs.relu(),
            Self::Relu2 => xs.relu()?.sqr(),
            Self::Relu6 => xs.clamp(0f32, 6f32),
            Self::Silu => xs.silu(),
            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
            Self::Swiglu => candle_nn::ops::swiglu(xs),
            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
            Self::GeluPytorchTanh => xs.gelu(),
            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
        }
    }
}

impl TryInto<candle_nn::Activation> for Activation {
    type Error = candle_core::Error;

    fn try_into(self) -> Result<candle_nn::Activation> {
        match self {
            Self::Gelu => Ok(candle_nn::Activation::Gelu),
            Self::Relu => Ok(candle_nn::Activation::Relu),
            Self::Silu => Ok(candle_nn::Activation::Silu),
            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
            Self::Swish => Ok(candle_nn::Activation::Swish),
            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    pub padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
}

impl Default for Conv3dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}

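/// A bias-free 3D convolution implemented as two 2D convolutions, one per
/// temporal slice of the kernel. The decomposition assumes a temporal kernel
/// (and input) depth of exactly 2.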
pub struct Conv3dNoBias {
    conv2d_1: Conv2d,
    conv2d_2: Conv2d,
}

impl Conv3dNoBias {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_sizes: [usize; 3],
        cfg: Conv3dConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let ws = vb.get(
            (
                out_channels,
                in_channels / cfg.groups,
                kernel_sizes[0],
                kernel_sizes[1],
                kernel_sizes[2],
            ),
            "weight",
        )?;

        let w1 = ws.i((.., .., 0, .., ..))?;
        let w2 = ws.i((.., .., 1, .., ..))?;

        let cfg = Conv2dConfig {
            padding: cfg.padding,
            stride: cfg.stride,
            dilation: cfg.dilation,
            groups: cfg.groups,
            cudnn_fwd_algo: None,
        };

        Ok(Self {
            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
        })
    }

    pub fn weight(&self) -> Result<Tensor> {
        let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?;
        let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?;
        Tensor::cat(&[w1, w2], 2)
    }
}

impl Module for Conv3dNoBias {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs1 = xs.i((.., .., 0, .., ..))?;
        let xs2 = xs.i((.., .., 1, .., ..))?;

        (Convolution.forward_2d(&self.conv2d_1, &xs1)?
            + Convolution.forward_2d(&self.conv2d_2, &xs2)?)?
            .unsqueeze(2)
    }
}

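/// Helpers for detecting infinities: `is_inf` produces an elementwise mask, and
/// `any` reports whether any element of a (typically boolean) tensor is non-zero.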
pub trait TensorInfExtend {
    fn is_inf(&self) -> Result<Self>
    where
        Self: Sized;
    fn any(&self) -> Result<bool>;
}

impl TensorInfExtend for Tensor {
    fn is_inf(&self) -> Result<Self> {
        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
    }

    fn any(&self) -> Result<bool> {
        let sum = self.sum_all()?;
        match self.dtype() {
            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
            DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
                candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with .any")
            }
        }
    }
}

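/// Clamps `xs` to just inside the largest finite value of its dtype, tightening
/// the bound further if infinities are present. Intended to keep activations
/// finite ahead of half-precision computation.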
pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
    let mut max = match xs.dtype() {
        DType::U8 => u8::MAX as f32 - 1000.,
        DType::U32 => u32::MAX as f32 - 1000.,
        DType::I16 => i16::MAX as f32 - 1000.,
        DType::I32 => i32::MAX as f32 - 1000.,
        DType::I64 => i64::MAX as f32 - 1000.,
        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
        DType::F32 => f32::MAX - 1000.,
        DType::F64 => f64::MAX as f32 - 1000.,
        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
        DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
            candle_core::bail!(
                "f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with `clamp_for_f16`"
            )
        }
    };
    if xs.is_inf()?.any()? {
        max -= 1000.;
    }
    xs.clamp(-max, max)
}

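/// Numeric limits (minimum, maximum, machine epsilon) for a floating-point
/// [`DType`].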
pub struct FloatInfo {
    pub min: f64,
    pub max: f64,
    pub eps: f64,
    pub dtype: DType,
}

pub trait GetFloatInfo {
    fn finfo(&self) -> Result<FloatInfo>;
}

impl GetFloatInfo for DType {
    fn finfo(&self) -> Result<FloatInfo> {
        let finfo = match self {
            Self::BF16 => FloatInfo {
                min: bf16::MIN.to_f64(),
                max: bf16::MAX.to_f64(),
                eps: bf16::EPSILON.to_f64(),
                dtype: DType::BF16,
            },
            Self::F16 => FloatInfo {
                min: f16::MIN.to_f64(),
                max: f16::MAX.to_f64(),
                eps: f16::EPSILON.to_f64(),
                dtype: DType::F16,
            },
            Self::F32 => FloatInfo {
                min: f32::MIN as f64,
                max: f32::MAX as f64,
                eps: f32::EPSILON as f64,
                dtype: DType::F32,
            },
            Self::F64 => FloatInfo {
                min: f64::MIN,
                max: f64::MAX,
                eps: f64::EPSILON,
                dtype: DType::F64,
            },
            Self::F8E4M3 => FloatInfo {
                min: F8E4M3::MIN.to_f64(),
                max: F8E4M3::MAX.to_f64(),
                eps: F8E4M3::EPSILON.to_f64(),
                dtype: DType::F8E4M3,
            },
            other => {
                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
            }
        };
        Ok(finfo)
    }
}

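/// The standard gated MLP block, computing `down(act(gate(x)) * up(x))`, with
/// column/row tensor-parallel sharding and optional quantization.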
#[derive(Clone)]
pub struct Mlp {
    pub gate: Arc<dyn QuantMethod>,
    pub up: Arc<dyn QuantMethod>,
    pub down: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl Mlp {
    pub fn new(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn new_merged(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        chunks: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
        let gate_up_projs = ColumnParallelLayer::new_merged(
            hidden_size,
            intermediate_size * 2,
            2,
            quantization_config,
            false,
            comm,
            vb.pp("gate_up_proj"),
        )?;

        Ok(Self {
            gate: gate_up_projs[0].to_owned(),
            up: gate_up_projs[1].to_owned(),
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn replicate(
        params: &[usize],
        vb: ShardedVarBuilder,
        act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Self::new(vb, params[0], params[1], &None, act, comm)
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        let mut res = self
            .down
            .forward(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
        let mut res =
            MatMul.qmethod_matmul(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?, &*self.down)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate, &mut self.up, &mut self.down]
    }

    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }

    fn get_params(&self) -> &[usize] {
        &self.params
    }

    fn hidden_act(&self) -> Activation {
        self.act
    }

    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let gate = if let Some(ref delta) = deltas[0] {
            self.gate.add_delta_w(delta)?
        } else {
            self.gate.clone()
        };
        let up = if let Some(ref delta) = deltas[1] {
            self.up.add_delta_w(delta)?
        } else {
            self.up.clone()
        };
        let down = if let Some(ref delta) = deltas[2] {
            self.down.add_delta_w(delta)?
        } else {
            self.down.clone()
        };

        Ok(Box::new(Self {
            gate,
            up,
            down,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate.dtype_and_device()
    }
}

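/// Average pooling over the last two dimensions with an explicit stride,
/// wrapping [`Tensor::avg_pool2d_with_stride`].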
pub struct AvgPool2d {
    kernel_size: usize,
    stride: usize,
}

impl AvgPool2d {
    pub fn new(kernel_size: usize, stride: usize) -> Self {
        Self {
            kernel_size,
            stride,
        }
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
    }
}

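/// Reflection padding for `(N, C, H, W)` tensors, mirroring the rows/columns
/// adjacent to each edge (the edge itself is not repeated), in the spirit of
/// `torch.nn.ReflectionPad2d`. Padding order is `(left, right, top, bottom)`.
///
/// An illustrative sketch (shapes are placeholders):
///
/// ```ignore
/// let pad = ReflectionPad2d::new((1, 1, 2, 2));
/// // xs: (N, C, H, W) -> (N, C, H + 4, W + 2)
/// let ys = pad.forward(&xs)?;
/// ```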
pub struct ReflectionPad2d {
    padding: (usize, usize, usize, usize),
}

impl ReflectionPad2d {
    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
        Self { padding }
    }
}

impl Module for ReflectionPad2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;

        let (_n, _c, h, w) = xs.dims4()?;

        // Reflect the columns just inside the left/right edges. The index tensors
        // must live on the same device as `xs`.
        let left_pad = if pad_left > 0 {
            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
        } else {
            None
        };

        let right_pad = if pad_right > 0 {
            let start = w as i64 - 2;
            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
        } else {
            None
        };

        let x_padded_width = match (left_pad, right_pad) {
            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
            (None, None) => xs.clone(),
        };

        // Then reflect the rows just inside the top/bottom edges of the
        // width-padded tensor.
        let top_pad = if pad_top > 0 {
            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
        } else {
            None
        };

        let bottom_pad = if pad_bottom > 0 {
            let start = h as i64 - 2;
            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
        } else {
            None
        };

        let x_padded = match (top_pad, bottom_pad) {
            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
            (None, None) => x_padded_width,
        };

        Ok(x_padded)
    }
}

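/// An embedding whose output is multiplied by a constant `scale`
/// (e.g. `sqrt(hidden_size)` in Gemma-style models).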
pub struct ScaledEmbedding {
    scale: f64,
    pub embedding: Tensor,
}

impl ScaledEmbedding {
    pub fn new(scale: f64, embedding: Embedding) -> Self {
        Self {
            scale,
            embedding: embedding.embeddings().clone(),
        }
    }

    pub fn embeddings(&self) -> &Tensor {
        &self.embedding
    }
}

impl Module for ScaledEmbedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
        xs.apply(&embedding)? * self.scale
    }
}