1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6 quantized::{QMatMul, QTensor},
7 Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10 BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
11 LayerNorm, LayerNormConfig, Linear, Module,
12};
13use float8::F8E4M3;
14use half::{bf16, f16};
15use mistralrs_quant::{
16 AfqLayer, ColumnParallelLayer, Convolution, QuantMethod, QuantizedConfig, RowParallelLayer,
17 ShardedVarBuilder,
18};
19use serde::{Deserialize, Serialize};
20
21pub use crate::attention::Sdpa;
22pub use crate::layers_masker::CausalMasker;
23pub use crate::layers_utils::repeat_kv;
24use crate::{
25 amoe::{AnyMoeTrainableLayer, MlpLayer},
26 gguf::Content,
27 models::{llama, smollm3},
28 ops::SplitOp,
29 vision_models::{
30 gemma3::config::Gemma3TextConfig,
31 gemma3n::config::Gemma3nTextConfig,
32 llama4,
33 mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
34 phi4::Phi4MMConfig,
35 },
36};
37
38pub use mistralrs_quant::MatMul;
39
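/// Load an embedding table (`in_size` rows of dimension `out_size`). When the model
/// ships an AFQ quantization config, the packed weight is loaded through `AfqLayer`
/// and dequantized back to a dense tensor; otherwise the plain `weight` tensor is read.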
40pub fn embedding(
41 in_size: usize,
42 out_size: usize,
43 vb: ShardedVarBuilder,
44 config: &Option<QuantizedConfig>,
45) -> Result<Embedding> {
46 let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
48 let afq_layer =
49 AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
50 afq_layer.dequantize_w()?
51 } else {
52 vb.get_with_hints((in_size, out_size), "weight", Default::default())?
53 };
54 Ok(Embedding::new(embeddings, out_size))
55}
56
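/// Load a `LayerNorm`; the bias is only read when the config marks the layer as affine.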
57pub fn layer_norm<C: Into<LayerNormConfig>>(
58 size: usize,
59 config: C,
60 vb: ShardedVarBuilder,
61) -> Result<LayerNorm> {
62 let config = config.into();
63 let weight = vb.get(size, "weight")?;
64 if config.affine {
65 let bias = vb.get(size, "bias")?;
66 Ok(LayerNorm::new(weight, bias, config.eps))
67 } else {
68 Ok(LayerNorm::new_no_bias(weight, config.eps))
69 }
70}
71
72pub fn batch_norm<C: Into<BatchNormConfig>>(
73 num_features: usize,
74 config: C,
75 vb: ShardedVarBuilder,
76) -> Result<BatchNorm> {
77 let config = config.into();
78 if config.eps < 0. {
79 candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
80 }
81 let running_mean = vb.get(num_features, "running_mean")?;
82 let running_var = vb.get(num_features, "running_var")?;
83
84 if config.affine {
85 let weight = vb.get(num_features, "weight")?;
86 let bias = vb.get(num_features, "bias")?;
87 BatchNorm::new(
88 num_features,
89 running_mean,
90 running_var,
91 weight,
92 bias,
93 config.eps,
94 )
95 } else {
96 BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
97 }
98}
99
100pub fn group_norm(
101 num_groups: usize,
102 num_channels: usize,
103 eps: f64,
104 vb: ShardedVarBuilder,
105) -> Result<GroupNorm> {
106 let weight = vb.get(num_channels, "weight")?;
107 let bias = vb.get(num_channels, "bias")?;
108 GroupNorm::new(weight, bias, num_channels, num_groups, eps)
109}
110
111pub fn conv2d(
112 in_channels: usize,
113 out_channels: usize,
114 kernel_size: usize,
115 cfg: Conv2dConfig,
116 vb: ShardedVarBuilder,
117) -> Result<Conv2d> {
118 let ws = vb.get(
119 (
120 out_channels,
121 in_channels / cfg.groups,
122 kernel_size,
123 kernel_size,
124 ),
125 "weight",
126 )?;
127 let bs = vb.get(out_channels, "bias")?;
128 Ok(Conv2d::new(ws, Some(bs), cfg))
129}
130
131pub fn conv2d_no_bias(
132 in_channels: usize,
133 out_channels: usize,
134 kernel_size: usize,
135 cfg: Conv2dConfig,
136 vb: ShardedVarBuilder,
137) -> Result<Conv2d> {
138 let ws = vb.get(
139 (
140 out_channels,
141 in_channels / cfg.groups,
142 kernel_size,
143 kernel_size,
144 ),
145 "weight",
146 )?;
147 Ok(Conv2d::new(ws, None, cfg))
148}
149
150pub fn conv1d(
151 in_channels: usize,
152 out_channels: usize,
153 kernel_size: usize,
154 cfg: Conv1dConfig,
155 vb: ShardedVarBuilder,
156) -> Result<Conv1d> {
157 let ws = vb.get(
158 (out_channels, in_channels / cfg.groups, kernel_size),
159 "weight",
160 )?;
161 let bs = vb.get(out_channels, "bias")?;
162 Ok(Conv1d::new(ws, Some(bs), cfg))
163}
164
165pub fn conv1d_no_bias(
166 in_channels: usize,
167 out_channels: usize,
168 kernel_size: usize,
169 cfg: Conv1dConfig,
170 vb: ShardedVarBuilder,
171) -> Result<Conv1d> {
172 let ws = vb.get(
173 (out_channels, in_channels / cfg.groups, kernel_size),
174 "weight",
175 )?;
176 Ok(Conv1d::new(ws, None, cfg))
177}
178
179pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
180 let ws = vb.get((out_dim, in_dim), "weight")?;
181 let bs = vb.get(out_dim, "bias")?;
182 Ok(Linear::new(ws, Some(bs)))
183}
184
185pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
186 let ws = vb.get((out_dim, in_dim), "weight")?;
187 Ok(Linear::new(ws, None))
188}
189
190pub fn linear_b(
191 in_dim: usize,
192 out_dim: usize,
193 bias: bool,
194 vb: ShardedVarBuilder,
195) -> Result<Linear> {
196 if bias {
197 linear(in_dim, out_dim, vb)
198 } else {
199 linear_no_bias(in_dim, out_dim, vb)
200 }
201}
202
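/// RMSNorm with a learned scale, evaluated through `candle_nn::ops::rms_norm`.
/// `new_gemma` adds 1 to the stored weight (Gemma checkpoints keep `weight - 1`),
/// and `undo_gemma` reverses that shift when the raw checkpoint weight is needed.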
203#[derive(Debug, Clone)]
204pub struct RmsNorm {
205 eps: f64,
206 weight: Tensor,
207}
208
209impl RmsNorm {
210 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
211 let w = vb.get(size, "weight")?;
212 Ok(Self { eps, weight: w })
213 }
214
215 pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
217 let w = vb.get(size, "weight")?;
218 let w = (w + 1.0)?;
219 Ok(Self { eps, weight: w })
220 }
221
222 pub fn new_gemma_3n(
224 size: usize,
225 eps: f64,
226 with_scale: bool,
227 vb: ShardedVarBuilder,
228 ) -> Result<Self> {
229 let w = if with_scale {
230 vb.get(size, "weight")?
231 } else {
232 Tensor::ones(size, vb.dtype(), vb.device())?
233 };
234 Ok(Self { eps, weight: w })
235 }
236
237 pub fn undo_gemma(&self) -> Result<Self> {
239 Ok(Self {
240 eps: self.eps,
241 weight: (&self.weight - 1.0)?,
242 })
243 }
244
245 pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
246 Ok(Self { eps, weight: w })
247 }
248
249 pub fn weight(&self) -> &Tensor {
250 &self.weight
251 }
252}
253
254impl Module for RmsNorm {
255 fn forward(&self, x: &Tensor) -> Result<Tensor> {
256 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
257 }
258}
259
260#[derive(Debug, Clone)]
261pub struct F32RmsNorm {
262 w: Tensor,
263 eps: f64,
264}
265
266impl F32RmsNorm {
267 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
268 Ok(Self {
269 w: vb.get((size,), "weight")?,
270 eps,
271 })
272 }
273
274 pub fn weight(&self) -> &Tensor {
275 &self.w
276 }
277}
278
279impl Module for F32RmsNorm {
280 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
281 let initial_type = xs.dtype();
282 let mut xs = xs.to_dtype(DType::F32)?;
283 let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
284 xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
285 xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
286 }
287}
288
289#[derive(Debug, Clone)]
290pub struct QRmsNorm {
291 eps: f64,
292 weight: Tensor,
293}
294
295impl QRmsNorm {
296 pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
297 let scale = scale.dequantize(&scale.device())?;
298 Ok(Self {
299 eps: eps as f64,
300 weight: scale,
301 })
302 }
303
304 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
305 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
306 }
307}
308
309#[derive(Debug, Clone)]
311pub struct PhiRotaryEmbedding {
312 short_sin: Tensor,
313 short_cos: Tensor,
314 long_cos: Option<Tensor>,
315 long_sin: Option<Tensor>,
316 original_max_position_embeddings: usize,
317}
318
319#[derive(Debug, Clone, Deserialize, Serialize)]
320#[serde(rename_all = "lowercase")]
321pub enum ScaledRopeType {
322 #[serde(alias = "su")]
323 #[serde(alias = "longrope")]
324 Su,
325 #[serde(alias = "yarn")]
326 Yarn,
327 #[serde(alias = "dynamic")]
328 Dynamic,
329 #[serde(alias = "linear")]
330 Linear,
331}
332
333impl FromStr for ScaledRopeType {
334 type Err = candle_core::Error;
335 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
336 match s {
337 "su" | "longrope" => Ok(Self::Su),
338 "yarn" => Ok(Self::Yarn),
339 "linear" => Ok(Self::Linear),
340 "dynamic" => Ok(Self::Dynamic),
341 _ => Err(candle_core::Error::Msg(
342 "Expected either `su` or `yarn` scaled RoPE type.".to_string(),
343 )),
344 }
345 }
346}
347
348#[derive(Debug, Clone, Deserialize, Serialize)]
349#[serde(untagged)]
350pub enum PhiRopeScalingConfig {
351 Classic {
352 short_factor: Vec<f64>,
353 long_factor: Vec<f64>,
354 #[serde(rename = "type")]
355 scaling_type: ScaledRopeType,
356 },
357 Scaled {
358 short_factor: Vec<f64>,
359 long_factor: Vec<f64>,
360 #[serde(rename = "type")]
361 scaling_type: ScaledRopeType,
362 long_mscale: f64,
363 short_mscale: f64,
364 },
365}
366
367pub struct PhiRopeConfig {
368 pub rope_scaling: Option<PhiRopeScalingConfig>,
369 pub max_position_embeddings: usize,
370 pub original_max_position_embeddings: usize,
371 pub rope_theta: f64,
372 pub head_dim: usize,
373 pub partial_rotary_factor: Option<f64>,
374}
375
376impl PhiRotaryEmbedding {
377 fn new_classic_scaled(
378 short_factor: &[f64],
379 long_factor: &[f64],
380 scaling_type: &ScaledRopeType,
381 cfg: &PhiRopeConfig,
382 dtype: DType,
383 dev: &Device,
384 ) -> Result<Self> {
385 let max_seq_len = cfg.max_position_embeddings;
386 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
387
388 let scale =
390 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
391 let scaling_factor = if scale <= 1.0 {
392 1.0
393 } else {
394 match scaling_type {
395 ScaledRopeType::Su => {
396 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
397 }
398 ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
399 _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
400 }
401 };
402
403 let inv_freq_long = (0..dim)
405 .step_by(2)
406 .enumerate()
407 .map(|(k, i)| {
408 (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
409 })
410 .collect::<Vec<_>>();
411 let inv_freq_short = (0..dim)
412 .step_by(2)
413 .enumerate()
414 .map(|(k, i)| {
415 (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
416 })
417 .collect::<Vec<_>>();
418 let inv_freq_len = inv_freq_long.len();
419
420 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
421 .to_dtype(DType::F32)?
422 .reshape((max_seq_len, 1))?;
423
424 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
426 let freqs_long = t.matmul(&inv_freq_long)?;
427 let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
428 let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
429
430 let inv_freq_short =
432 Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
433 let freqs_short = t.matmul(&inv_freq_short)?;
434 let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
435 let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
436
437 Ok(Self {
438 short_cos,
439 short_sin,
440 long_cos: Some(long_cos),
441 long_sin: Some(long_sin),
442 original_max_position_embeddings: cfg.original_max_position_embeddings,
443 })
444 }
445
446 fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
447 let max_seq_len = cfg.max_position_embeddings;
448 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
449
450 let inv_freq: Vec<_> = (0..dim)
451 .step_by(2)
452 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
453 .collect();
454 let inv_freq_len = inv_freq.len();
455 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
456 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
457 .to_dtype(DType::F32)?
458 .reshape((max_seq_len, 1))?;
459 let freqs = t.matmul(&inv_freq)?;
460 let sin = freqs.sin()?.to_dtype(dtype)?;
461 let cos = freqs.cos()?.to_dtype(dtype)?;
462 Ok(Self {
463 short_cos: cos,
464 short_sin: sin,
465 long_cos: None,
466 long_sin: None,
467 original_max_position_embeddings: cfg.original_max_position_embeddings,
468 })
469 }
470
471 #[allow(clippy::too_many_arguments)]
472 fn new_scaled(
473 short_factor: &[f64],
474 long_factor: &[f64],
475 scaling_type: &ScaledRopeType,
476 long_mscale: f64,
477 short_mscale: f64,
478 cfg: &PhiRopeConfig,
479 dtype: DType,
480 dev: &Device,
481 ) -> Result<Self> {
482 let max_seq_len = cfg.max_position_embeddings;
483 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
484
485 if !matches!(scaling_type, ScaledRopeType::Su) {
486 candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
487 }
488
489 if short_factor.len() != dim / 2 {
490 candle_core::bail!(
491 "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
492 short_factor.len(),
493 dim / 2
494 );
495 }
496 if long_factor.len() != dim / 2 {
497 candle_core::bail!(
498 "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
499 long_factor.len(),
500 dim / 2
501 );
502 }
503
504 let inv_freq_short: Vec<_> = (0..dim)
506 .step_by(2)
507 .enumerate()
508 .map(|(k, i)| {
509 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
510 })
511 .collect();
512 let inv_freq_len_short = inv_freq_short.len();
513 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
514 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
515 .to_dtype(DType::F32)?
516 .reshape((max_seq_len, 1))?;
517 let freqs_short = t_short.matmul(&inv_freq_short)?;
518 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
519 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
520
521 let inv_freq_long: Vec<_> = (0..dim)
523 .step_by(2)
524 .enumerate()
525 .map(|(k, i)| {
526 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
527 })
528 .collect();
529 let inv_freq_len_long = inv_freq_long.len();
530 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
531 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
532 .to_dtype(DType::F32)?
533 .reshape((max_seq_len, 1))?;
534 let freqs_long = t_long.matmul(&inv_freq_long)?;
535 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
536 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
537 Ok(Self {
538 short_cos: cos_short,
539 short_sin: sin_short,
540 long_cos: Some(cos_long),
541 long_sin: Some(sin_long),
542 original_max_position_embeddings: cfg.original_max_position_embeddings,
543 })
544 }
545
546 pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
547 let cfg: PhiRopeConfig = cfg.into();
548
549 match &cfg.rope_scaling {
550 Some(PhiRopeScalingConfig::Classic {
551 short_factor,
552 long_factor,
553 scaling_type,
554 }) => {
555 Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
556 }
557
558 Some(PhiRopeScalingConfig::Scaled {
559 short_factor,
560 long_factor,
561 scaling_type,
562 long_mscale,
563 short_mscale,
564 }) => Self::new_scaled(
565 short_factor,
566 long_factor,
567 scaling_type,
568 *long_mscale,
569 *short_mscale,
570 &cfg,
571 dtype,
572 dev,
573 ),
574
575 None => Self::new_unscaled(&cfg, dtype, dev),
576 }
577 }
578
579 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
581 if self.long_cos.is_none() {
582 return (&self.short_sin, &self.short_cos);
583 }
584 let seq_len = position_ids.iter().max().unwrap() + 1;
585 if seq_len > self.original_max_position_embeddings {
586 (
587 self.long_sin.as_ref().unwrap(),
588 self.long_cos.as_ref().unwrap(),
589 )
590 } else {
591 (&self.short_sin, &self.short_cos)
592 }
593 }
594
595 pub fn forward(
596 &self,
597 q: &Tensor,
598 k: &Tensor,
599 seqlen_offsets: &[usize],
600 position_ids: &[usize],
601 ) -> Result<(Tensor, Tensor)> {
602 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
603 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
604
605 let rot_dim = cos.dim(D::Minus1)? * 2;
606
607 if rot_dim != q.dim(D::Minus1)? {
610 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
611 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
612 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
613 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
614
615 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
616 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
617 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
618 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
619 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
620 (q_embed, k_embed)
621 } else {
622 let mut q_embeds = Vec::new();
623 let mut k_embeds = Vec::new();
624 for (i, offset) in seqlen_offsets.iter().enumerate() {
625 let cos = cos.narrow(0, *offset, seq_len)?;
626 let sin = sin.narrow(0, *offset, seq_len)?;
627 let q_embed = candle_nn::rotary_emb::rope(
628 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
629 &cos,
630 &sin,
631 )?;
632 let k_embed = candle_nn::rotary_emb::rope(
633 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
634 &cos,
635 &sin,
636 )?;
637 q_embeds.push(q_embed);
638 k_embeds.push(k_embed);
639 }
640 let q_rot = Tensor::cat(&q_embeds, 0)?;
641 let k_rot = Tensor::cat(&k_embeds, 0)?;
642 (q_rot, k_rot)
643 };
644
645 Ok((
646 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
647 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
648 ))
649 } else if seqlen_offsets.len() == 1 {
650 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
651 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
652 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
653 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
654 Ok((q_embed, k_embed))
655 } else {
656 let mut q_embeds = Vec::new();
657 let mut k_embeds = Vec::new();
658 for (i, offset) in seqlen_offsets.iter().enumerate() {
659 let cos = cos.narrow(0, *offset, seq_len)?;
660 let sin = sin.narrow(0, *offset, seq_len)?;
661 let q_embed =
662 candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
663 let k_embed =
664 candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
665 q_embeds.push(q_embed);
666 k_embeds.push(k_embed);
667 }
668 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
669 }
670 }
671}
672
673#[derive(Debug, Clone)]
675pub struct Llama3RotaryEmbedding(RotaryEmbedding);
676
677#[derive(Debug, Clone, Deserialize, Serialize, Default)]
678pub enum Llama3RopeType {
679 #[serde(rename = "llama3")]
680 Llama3,
681 #[serde(rename = "linear")]
682 Linear,
683 #[default]
684 #[serde(rename = "default")]
685 Default,
686}
687
688#[derive(Debug, Clone, Deserialize, Serialize, Default)]
689pub struct Llama3RopeConfig {
690 pub factor: f32,
691 pub low_freq_factor: Option<f32>,
692 pub high_freq_factor: Option<f32>,
693 pub original_max_position_embeddings: Option<usize>,
694 pub rope_type: Llama3RopeType,
695}
696
697fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
698 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
699 (0..head_dim)
700 .step_by(2)
701 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
702 .collect()
703}
704
705fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
706 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
707 (0..head_dim)
708 .step_by(2)
709 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
710 .collect()
711}
712
713impl Llama3RotaryEmbedding {
715 pub fn new_llama3(
716 dtype: DType,
717 cfg: &llama::Config,
718 dev: &Device,
719 is_gpt_neox: bool,
720 ) -> Result<Self> {
721 match &cfg.rope_scaling {
722 None
723 | Some(Llama3RopeConfig {
724 rope_type: Llama3RopeType::Default,
725 ..
726 }) => Ok(Self(RotaryEmbedding::new(
727 cfg.rope_theta,
728 cfg.hidden_size / cfg.num_attention_heads,
729 cfg.max_position_embeddings,
730 dev,
731 is_gpt_neox,
732 dtype,
733 )?)),
734 Some(Llama3RopeConfig {
735 rope_type: Llama3RopeType::Llama3,
736 factor,
737 low_freq_factor,
738 high_freq_factor,
739 original_max_position_embeddings,
740 }) => {
741 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
742 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
743 let original_max_position_embeddings = original_max_position_embeddings
744 .context("original_max_position_embeddings is required")?;
745
746 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
747 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
748
749 let inv_freq = calculate_default_inv_freq(cfg)
750 .into_iter()
751 .map(|freq| {
752 let wavelen = 2. * PI / freq;
753 if wavelen < high_freq_wavelen {
754 freq
755 } else if wavelen > low_freq_wavelen {
756 freq / *factor
757 } else {
758 let smooth = (original_max_position_embeddings as f32 / wavelen
759 - low_freq_factor)
760 / (high_freq_factor - low_freq_factor);
761 (1. - smooth) * freq / *factor + smooth * freq
762 }
763 })
764 .collect::<Vec<_>>();
765 let inv_freq_len = inv_freq.len();
766 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
767 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
768 .to_dtype(DType::F32)?
769 .reshape((cfg.max_position_embeddings, 1))?;
770 let freqs = t.matmul(&inv_freq)?;
771 let sin = freqs.sin()?.to_dtype(dtype)?;
772 let cos = freqs.cos()?.to_dtype(dtype)?;
773 Ok(Self(RotaryEmbedding {
774 sin,
775 cos,
776 is_gpt_neox,
777 }))
778 }
779 Some(Llama3RopeConfig {
780 rope_type: Llama3RopeType::Linear,
781 factor,
782 ..
783 }) => {
784 let inv_freq_vec = calculate_default_inv_freq(cfg)
785 .into_iter()
786 .map(|freq| freq / *factor)
787 .collect::<Vec<_>>();
788 let inv_freq_len = inv_freq_vec.len();
789 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
790 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
791 .to_dtype(DType::F32)?
792 .reshape((cfg.max_position_embeddings, 1))?;
793 let freqs = t.matmul(&inv_freq)?;
794 let sin = freqs.sin()?.to_dtype(dtype)?;
795 let cos = freqs.cos()?.to_dtype(dtype)?;
796 Ok(Self(RotaryEmbedding {
797 sin,
798 cos,
799 is_gpt_neox,
800 }))
801 }
802 }
803 }
804
805 pub fn new_llama4(
806 dtype: DType,
807 cfg: &llama4::TextConfig,
808 dev: &Device,
809 is_gpt_neox: bool,
810 ) -> Result<Self> {
811 match &cfg.rope_scaling {
812 None
813 | Some(Llama3RopeConfig {
814 rope_type: Llama3RopeType::Default,
815 ..
816 }) => Ok(Self(RotaryEmbedding::new(
817 cfg.rope_theta,
818 cfg.hidden_size / cfg.num_attention_heads,
819 cfg.max_position_embeddings,
820 dev,
821 is_gpt_neox,
822 dtype,
823 )?)),
824 Some(Llama3RopeConfig {
825 rope_type: Llama3RopeType::Llama3,
826 factor,
827 low_freq_factor,
828 high_freq_factor,
829 original_max_position_embeddings,
830 }) => {
831 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
832 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
833 let original_max_position_embeddings = original_max_position_embeddings
834 .context("original_max_position_embeddings is required")?;
835
836 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
837 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
838
839 let inv_freq = calculate_default_inv_freq_llama4(cfg)
840 .into_iter()
841 .map(|freq| {
842 let wavelen = 2. * PI / freq;
843 if wavelen < high_freq_wavelen {
844 freq
845 } else if wavelen > low_freq_wavelen {
846 freq / *factor
847 } else {
848 let smooth = (original_max_position_embeddings as f32 / wavelen
849 - low_freq_factor)
850 / (high_freq_factor - low_freq_factor);
851 (1. - smooth) * freq / *factor + smooth * freq
852 }
853 })
854 .collect::<Vec<_>>();
855 let inv_freq_len = inv_freq.len();
856 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
857 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
858 .to_dtype(DType::F32)?
859 .reshape((cfg.max_position_embeddings, 1))?;
860 let freqs = t.matmul(&inv_freq)?;
861 let sin = freqs.sin()?.to_dtype(dtype)?;
862 let cos = freqs.cos()?.to_dtype(dtype)?;
863 Ok(Self(RotaryEmbedding {
864 sin,
865 cos,
866 is_gpt_neox,
867 }))
868 }
869 Some(Llama3RopeConfig {
870 rope_type: Llama3RopeType::Linear,
871 factor,
872 ..
873 }) => {
874 let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
875 .into_iter()
876 .map(|freq| freq / *factor)
877 .collect::<Vec<_>>();
878 let inv_freq_len = inv_freq_vec.len();
879 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
880 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
881 .to_dtype(DType::F32)?
882 .reshape((cfg.max_position_embeddings, 1))?;
883 let freqs = t.matmul(&inv_freq)?;
884 let sin = freqs.sin()?.to_dtype(dtype)?;
885 let cos = freqs.cos()?.to_dtype(dtype)?;
886 Ok(Self(RotaryEmbedding {
887 sin,
888 cos,
889 is_gpt_neox,
890 }))
891 }
892 }
893 }
894
895 pub fn new_mllama3(
896 dtype: DType,
897 cfg: &MLlamaTextConfig,
898 dev: &Device,
899 is_gpt_neox: bool,
900 ) -> Result<Self> {
901 match &cfg.rope_scaling {
902 None
903 | Some(MLlamaRopeScaling {
904 rope_type: MLlamaRopeType::Default,
905 ..
906 }) => Ok(Self(RotaryEmbedding::new(
907 cfg.rope_theta,
908 cfg.hidden_size / cfg.num_attention_heads,
909 cfg.max_position_embeddings,
910 dev,
911 is_gpt_neox,
912 dtype,
913 )?)),
914 Some(MLlamaRopeScaling {
915 rope_type: MLlamaRopeType::Llama3,
916 original_max_position_embeddings,
917 factor,
918 attention_factor: _,
919 beta_fast: _,
920 beta_slow: _,
921 short_factor: _,
922 long_factor: _,
923 low_freq_factor,
924 high_freq_factor,
925 }) => {
926 let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
927 let low_freq_factor = low_freq_factor
928 .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
929 let high_freq_factor = high_freq_factor
930 .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
931
932 let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
933 let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
934
935 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
936
937 let inv_freq = (0..head_dim)
938 .step_by(2)
939 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
940 .map(|freq| {
941 let wavelen = 2. * PI / freq;
942 if wavelen < high_freq_wavelen {
943 freq
944 } else if wavelen > low_freq_wavelen {
945 freq / factor
946 } else {
947 let smooth = (*original_max_position_embeddings as f32 / wavelen
948 - low_freq_factor)
949 / (high_freq_factor - low_freq_factor);
950 (1. - smooth) * freq / factor + smooth * freq
951 }
952 })
953 .collect::<Vec<_>>();
954 let inv_freq_len = inv_freq.len();
955 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
956
957 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
958 .to_dtype(DType::F32)?
959 .reshape((cfg.max_position_embeddings, 1))?;
960 let freqs = t.matmul(&inv_freq)?;
961 let sin = freqs.sin()?.to_dtype(dtype)?;
962 let cos = freqs.cos()?.to_dtype(dtype)?;
963 Ok(Self(RotaryEmbedding {
964 sin,
965 cos,
966 is_gpt_neox,
967 }))
968 }
969 Some(MLlamaRopeScaling {
970 rope_type: other, ..
971 }) => {
972 candle_core::bail!(
973 "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
974 )
975 }
976 }
977 }
978
979 pub fn forward(
980 &self,
981 q: &Tensor,
982 k: &Tensor,
983 seqlen_offsets: &[usize],
984 ) -> Result<(Tensor, Tensor)> {
985 self.0.forward(q, k, seqlen_offsets)
986 }
987}
988
989#[derive(Debug, Clone)]
991pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);
992
993#[derive(Debug, Clone, Deserialize, Serialize, Default)]
994pub enum SmolLm3RopeType {
995 #[serde(rename = "llama3")]
996 Llama3,
997 #[serde(rename = "linear")]
998 Linear,
999 #[default]
1000 #[serde(rename = "default")]
1001 Default,
1002}
1003
1004#[derive(Debug, Clone, Deserialize, Serialize, Default)]
1005pub struct SmolLm3RopeConfig {
1006 pub factor: f32,
1007 pub low_freq_factor: Option<f32>,
1008 pub high_freq_factor: Option<f32>,
1009 pub original_max_position_embeddings: Option<usize>,
1010 pub rope_type: SmolLm3RopeType,
1011}
1012
1013fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
1014 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1015 (0..head_dim)
1016 .step_by(2)
1017 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1018 .collect()
1019}
1020
1021impl SmolLm3RotaryEmbedding {
1022 pub fn new_llama3(
1023 dtype: DType,
1024 cfg: &smollm3::Config,
1025 dev: &Device,
1026 is_gpt_neox: bool,
1027 ) -> Result<Self> {
1028 match &cfg.rope_scaling {
1029 None
1030 | Some(SmolLm3RopeConfig {
1031 rope_type: SmolLm3RopeType::Default,
1032 ..
1033 }) => Ok(Self(RotaryEmbedding::new(
1034 cfg.rope_theta,
1035 cfg.hidden_size / cfg.num_attention_heads,
1036 cfg.max_position_embeddings,
1037 dev,
1038 is_gpt_neox,
1039 dtype,
1040 )?)),
1041 Some(SmolLm3RopeConfig {
1042 rope_type: SmolLm3RopeType::Llama3,
1043 factor,
1044 low_freq_factor,
1045 high_freq_factor,
1046 original_max_position_embeddings,
1047 }) => {
1048 let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
1049 let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
1050 let original_max_position_embeddings = original_max_position_embeddings
1051 .context("original_max_position_embeddings is required")?;
1052
1053 let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
1054 let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
1055
1056 let inv_freq = calculate_default_inv_freq_smollm3(cfg)
1057 .into_iter()
1058 .map(|freq| {
1059 let wavelen = 2. * PI / freq;
1060 if wavelen < high_freq_wavelen {
1061 freq
1062 } else if wavelen > low_freq_wavelen {
1063 freq / *factor
1064 } else {
1065 let smooth = (original_max_position_embeddings as f32 / wavelen
1066 - low_freq_factor)
1067 / (high_freq_factor - low_freq_factor);
1068 (1. - smooth) * freq / *factor + smooth * freq
1069 }
1070 })
1071 .collect::<Vec<_>>();
1072 let inv_freq_len = inv_freq.len();
1073 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1074 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1075 .to_dtype(DType::F32)?
1076 .reshape((cfg.max_position_embeddings, 1))?;
1077 let freqs = t.matmul(&inv_freq)?;
1078 let sin = freqs.sin()?.to_dtype(dtype)?;
1079 let cos = freqs.cos()?.to_dtype(dtype)?;
1080 Ok(Self(RotaryEmbedding {
1081 sin,
1082 cos,
1083 is_gpt_neox,
1084 }))
1085 }
1086 Some(SmolLm3RopeConfig {
1087 rope_type: SmolLm3RopeType::Linear,
1088 factor,
1089 ..
1090 }) => {
1091 let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
1092 .into_iter()
1093 .map(|freq| freq / *factor)
1094 .collect::<Vec<_>>();
1095 let inv_freq_len = inv_freq_vec.len();
1096 let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1097 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1098 .to_dtype(DType::F32)?
1099 .reshape((cfg.max_position_embeddings, 1))?;
1100 let freqs = t.matmul(&inv_freq)?;
1101 let sin = freqs.sin()?.to_dtype(dtype)?;
1102 let cos = freqs.cos()?.to_dtype(dtype)?;
1103 Ok(Self(RotaryEmbedding {
1104 sin,
1105 cos,
1106 is_gpt_neox,
1107 }))
1108 }
1109 }
1110 }
1111
1112 pub fn forward(
1113 &self,
1114 q: &Tensor,
1115 k: &Tensor,
1116 seqlen_offsets: &[usize],
1117 ) -> Result<(Tensor, Tensor)> {
1118 self.0.forward(q, k, seqlen_offsets)
1119 }
1120}
1121
1122#[derive(Debug, Clone)]
1124pub struct Qwen2VLRotaryEmbedding {
1125 inv_freq: Tensor,
1126 mrope_section: Vec<usize>,
1127}
1128
1129impl Qwen2VLRotaryEmbedding {
1130 pub fn new(
1131 base: f32,
1132 head_dim: usize,
1133 device: &Device,
1134 mrope_section: Vec<usize>,
1135 ) -> Result<Self> {
1136 let inv_freq: Vec<_> = (0..head_dim)
1137 .step_by(2)
1138 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1139 .collect();
1140 let inv_freq_len = inv_freq.len();
1141 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1142 Ok(Self {
1143 inv_freq,
1144 mrope_section,
1145 })
1146 }
1147
1148 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1150 let inv_freq_expanded =
1151 self.inv_freq
1152 .reshape((1, 1, (), 1))?
1153 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1154 let position_ids_expanded = position_ids.unsqueeze(2)?;
1155 let freqs = inv_freq_expanded
1156 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1157 .transpose(2, 3)?;
1158 let cos = freqs.cos()?;
1159 let sin = freqs.sin()?;
1160
1161 let cos = Tensor::cat(
1162 &cos.split(&self.mrope_section, D::Minus1)?
1163 .into_iter()
1164 .enumerate()
1165 .map(|(i, m)| m.i(i % 3))
1166 .collect::<Result<Vec<_>>>()?,
1167 D::Minus1,
1168 )?
1169 .squeeze(0)?
1170 .to_dtype(dtype)?
1171 .contiguous()?;
1172 let sin = Tensor::cat(
1173 &sin.split(&self.mrope_section, D::Minus1)?
1174 .into_iter()
1175 .enumerate()
1176 .map(|(i, m)| m.i(i % 3))
1177 .collect::<Result<Vec<_>>>()?,
1178 D::Minus1,
1179 )?
1180 .squeeze(0)?
1181 .to_dtype(dtype)?
1182 .contiguous()?;
1183
1184 Ok((cos, sin))
1185 }
1186
1187 pub fn forward(
1189 &self,
1190 (cos, sin): &(Tensor, Tensor),
1191 q: &mut Tensor,
1192 k: &mut Tensor,
1193 ) -> Result<()> {
1194 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1195 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1196 Ok(())
1197 }
1198}
1199
1200#[derive(Debug, Clone)]
1201pub struct Qwen2_5VLRotaryEmbedding {
1202 inv_freq: Tensor,
1203 mrope_section: Vec<usize>,
1204}
1205
1206impl Qwen2_5VLRotaryEmbedding {
1207 pub fn new(
1208 base: f32,
1209 head_dim: usize,
1210 device: &Device,
1211 mrope_section: Vec<usize>,
1212 ) -> Result<Self> {
1213 let inv_freq: Vec<_> = (0..head_dim)
1214 .step_by(2)
1215 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1216 .collect();
1217 let inv_freq_len = inv_freq.len();
1218 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1219 Ok(Self {
1220 inv_freq,
1221 mrope_section,
1222 })
1223 }
1224
1225 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1227 let inv_freq_expanded =
1228 self.inv_freq
1229 .reshape((1, 1, (), 1))?
1230 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1231 let position_ids_expanded = position_ids.unsqueeze(2)?;
1232 let freqs = inv_freq_expanded
1233 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1234 .transpose(2, 3)?;
1235 let cos = freqs.cos()?;
1236 let sin = freqs.sin()?;
1237
1238 let cos = Tensor::cat(
1239 &cos.split(&self.mrope_section, D::Minus1)?
1240 .into_iter()
1241 .enumerate()
1242 .map(|(i, m)| m.i(i % 3))
1243 .collect::<Result<Vec<_>>>()?,
1244 D::Minus1,
1245 )?
1246 .squeeze(0)?
1247 .to_dtype(dtype)?
1248 .contiguous()?;
1249 let sin = Tensor::cat(
1250 &sin.split(&self.mrope_section, D::Minus1)?
1251 .into_iter()
1252 .enumerate()
1253 .map(|(i, m)| m.i(i % 3))
1254 .collect::<Result<Vec<_>>>()?,
1255 D::Minus1,
1256 )?
1257 .squeeze(0)?
1258 .to_dtype(dtype)?
1259 .contiguous()?;
1260
1261 Ok((cos, sin))
1262 }
1263
1264 pub fn forward(
1265 &self,
1266 (cos, sin): &(Tensor, Tensor),
1267 q: &mut Tensor,
1268 k: &mut Tensor,
1269 ) -> Result<()> {
1270 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1271 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1272 Ok(())
1273 }
1274}
1275
1276#[derive(Debug, Clone)]
1277pub struct Qwen3VLRotaryEmbedding {
1278 inv_freq: Tensor,
1279 mrope_section: Vec<usize>,
1280}
1281
1282impl Qwen3VLRotaryEmbedding {
1283 pub fn new(
1284 base: f32,
1285 head_dim: usize,
1286 device: &Device,
1287 mrope_section: Vec<usize>,
1288 ) -> Result<Self> {
1289 let inv_freq: Vec<_> = (0..head_dim)
1290 .step_by(2)
1291 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1292 .collect();
1293 let inv_freq_len = inv_freq.len();
1294 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1295 Ok(Self {
1296 inv_freq,
1297 mrope_section,
1298 })
1299 }
1300
1301 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1303 let inv_freq_expanded =
1304 self.inv_freq
1305 .reshape((1, 1, (), 1))?
1306 .repeat((3, position_ids.dim(1)?, 1, 1))?;
1307 let position_ids_expanded = position_ids.unsqueeze(2)?;
1308 let freqs = inv_freq_expanded
1309 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1310 .transpose(2, 3)?;
1311 let cos = freqs.cos()?;
1312 let sin = freqs.sin()?;
1313
1314 let cos = Tensor::cat(
1315 &cos.split(&self.mrope_section, D::Minus1)?
1316 .into_iter()
1317 .enumerate()
1318 .map(|(i, m)| m.i(i % 3))
1319 .collect::<Result<Vec<_>>>()?,
1320 D::Minus1,
1321 )?
1322 .squeeze(0)?
1323 .to_dtype(dtype)?
1324 .contiguous()?;
1325 let sin = Tensor::cat(
1326 &sin.split(&self.mrope_section, D::Minus1)?
1327 .into_iter()
1328 .enumerate()
1329 .map(|(i, m)| m.i(i % 3))
1330 .collect::<Result<Vec<_>>>()?,
1331 D::Minus1,
1332 )?
1333 .squeeze(0)?
1334 .to_dtype(dtype)?
1335 .contiguous()?;
1336
1337 Ok((cos, sin))
1338 }
1339
1340 pub fn forward(
1341 &self,
1342 (cos, sin): &(Tensor, Tensor),
1343 q: &mut Tensor,
1344 k: &mut Tensor,
1345 ) -> Result<()> {
1346 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1347 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1348 Ok(())
1349 }
1350}
1351
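/// RoPE for DeepSeek-V2/V3 attention, applied to the `qk_rope_head_dim` slice in the
/// interleaved (`rope_i`) layout. Supports YaRN scaling: frequencies are blended
/// between the original and factor-scaled values using a linear ramp over the
/// correction range, and sin/cos are multiplied by the YaRN `mscale` factor.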
1352#[derive(Debug, Clone)]
1353pub struct DeepSeekV2RotaryEmbedding {
1354 sin: Tensor,
1355 cos: Tensor,
1356}
1357
1358#[derive(Debug, Clone, Deserialize, Serialize)]
1359#[serde(untagged)]
1360pub enum DeepSeekV2RopeScaling {
1361 Yarn {
1362 original_max_position_embeddings: usize,
1363 beta_fast: f32,
1364 beta_slow: f32,
1365 mscale: f32,
1366 mscale_all_dim: f32,
1367 factor: f32,
1368 #[serde(rename = "type")]
1369 scaling_type: ScaledRopeType,
1370 },
1371 LinearOrDynamic {
1372 #[serde(rename = "type")]
1373 scaling_type: ScaledRopeType,
1374 factor: f64,
1375 },
1376}
1377
1378pub struct DeepSeekV2RopeConfig {
1379 pub rope_scaling: Option<DeepSeekV2RopeScaling>,
1380 pub max_position_embeddings: usize,
1381 pub rope_theta: f32,
1382 pub qk_rope_head_dim: usize,
1383}
1384
1385impl DeepSeekV2RotaryEmbedding {
1386 fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1387 let max_seq_len = cfg.max_position_embeddings;
1388 let dim = cfg.qk_rope_head_dim;
1389
1390 let inv_freq: Vec<_> = (0..dim)
1391 .step_by(2)
1392 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1393 .collect();
1394 let inv_freq_len = inv_freq.len();
1395 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1396 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1397 .to_dtype(DType::F32)?
1398 .reshape((max_seq_len, 1))?;
1399 let freqs = t.matmul(&inv_freq)?;
1400
1401 let sin = freqs.sin()?.to_dtype(dtype)?;
1402 let cos = freqs.cos()?.to_dtype(dtype)?;
1403
1404 Ok(Self { sin, cos })
1405 }
1406
1407 fn yarn_find_correction_dim(
1408 num_rot: f32,
1409 dim: usize,
1410 base: f32,
1411 max_position_embeddings: usize,
1412 ) -> f32 {
1413 (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1414 / (2. * base.ln())
1415 }
1416
1417 fn yarn_find_correction_range(
1418 low_rot: f32,
1419 high_rot: f32,
1420 dim: usize,
1421 base: f32,
1422 max_position_embeddings: usize,
1423 ) -> (f32, f32) {
1424 let low =
1425 Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1426 let high =
1427 Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1428 (low.max(0.), high.min(dim as f32 - 1.))
1429 }
1430
1431 fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1432 if min == max {
1433 max += 0.001;
1435 }
1436 let linear_func =
1437 ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1438 linear_func.clamp(0., 1)
1439 }
1440
1441 pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1442 if scale <= 1. {
1443 return 1.;
1444 }
1445 0.1 * mscale * scale.ln() + 1.
1446 }
1447
1448 #[allow(clippy::too_many_arguments)]
1449 fn new_yarn(
1450 cfg: &DeepSeekV2RopeConfig,
1451 dtype: DType,
1452 dev: &Device,
1453 original_max_position_embeddings: usize,
1454 beta_fast: f32,
1455 beta_slow: f32,
1456 factor: f32,
1457 mscale: f32,
1458 mscale_all_dim: f32,
1459 ) -> Result<Self> {
1460 let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1461 .step_by(2)
1462 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1463 .collect();
1464 let freq_extra_len = freq_extra.len();
1465 let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1466 let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1467 .step_by(2)
1468 .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1469 .collect();
1470 let freq_inter_len = freq_inter.len();
1471 let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1472
1473 let (low, high) = Self::yarn_find_correction_range(
1474 beta_fast,
1475 beta_slow,
1476 cfg.qk_rope_head_dim,
1477 cfg.rope_theta,
1478 original_max_position_embeddings,
1479 );
1480 let inv_freq_mask =
1481 (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1482 let inv_freq = freq_inter
1483 .broadcast_mul(&(1. - &inv_freq_mask)?)?
1484 .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1485
1486 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1487 .to_dtype(DType::F32)?
1488 .reshape((cfg.max_position_embeddings, 1))?;
1489 let freqs = t.matmul(&inv_freq)?;
1490
1491 let mscale =
1492 Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1493 let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1494 let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1495
1496 Ok(Self { sin, cos })
1497 }
1498
1499 pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1500 match &cfg.rope_scaling {
1501 Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1502 scaling_type: _,
1503 factor: _,
1504 }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1505 Some(DeepSeekV2RopeScaling::Yarn {
1506 original_max_position_embeddings,
1507 beta_fast,
1508 beta_slow,
1509 factor,
1510 mscale,
1511 mscale_all_dim,
1512 scaling_type: _,
1513 }) => Self::new_yarn(
1514 cfg,
1515 dtype,
1516 dev,
1517 *original_max_position_embeddings,
1518 *beta_fast,
1519 *beta_slow,
1520 *factor,
1521 *mscale,
1522 *mscale_all_dim,
1523 ),
1524 None => Self::new_unscaled(cfg, dtype, dev),
1525 }
1526 }
1527
1528 pub fn forward(
1529 &self,
1530 q: &Tensor,
1531 k: &Tensor,
1532 seqlen_offsets: &[usize],
1533 ) -> Result<(Tensor, Tensor)> {
1534 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1535
1536 if seqlen_offsets.len() == 1 {
1537 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1538 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1539 let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1540 let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1541 Ok((q_embed, k_embed))
1542 } else {
1543 let mut q_embeds = Vec::new();
1544 let mut k_embeds = Vec::new();
1545 for (i, offset) in seqlen_offsets.iter().enumerate() {
1546 let cos = self.cos.narrow(0, *offset, seq_len)?;
1547 let sin = self.sin.narrow(0, *offset, seq_len)?;
1548 let q_embed = candle_nn::rotary_emb::rope_i(
1549 &q.i(i)?.unsqueeze(0)?.contiguous()?,
1550 &cos,
1551 &sin,
1552 )?;
1553 let k_embed = candle_nn::rotary_emb::rope_i(
1554 &k.i(i)?.unsqueeze(0)?.contiguous()?,
1555 &cos,
1556 &sin,
1557 )?;
1558 q_embeds.push(q_embed);
1559 k_embeds.push(k_embed);
1560 }
1561 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1562 }
1563 }
1564}
1565
1566#[derive(Debug, Clone)]
1567pub struct Phi4MMRotaryEmbedding {
1568 short_sin: Tensor,
1569 short_cos: Tensor,
1570 long_cos: Option<Tensor>,
1571 long_sin: Option<Tensor>,
1572 original_max_position_embeddings: usize,
1573}
1574
1575#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1576#[serde(rename_all = "lowercase")]
1577pub enum Phi4MMScaledRopeType {
1578 #[serde(alias = "longrope")]
1579 LongRope,
1580 #[default]
1581 Default,
1582}
1583
1584#[derive(Debug, Clone, Deserialize, Serialize)]
1585pub struct Phi4MMRopeScalingConfig {
1586 short_factor: Option<Vec<f64>>,
1587 long_factor: Option<Vec<f64>>,
1588 #[serde(rename = "type")]
1589 scaling_type: Phi4MMScaledRopeType,
1590}
1591
1592impl Phi4MMRotaryEmbedding {
1593 fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1594 let max_seq_len = cfg.max_position_embeddings;
1595 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1596
1597 let inv_freq: Vec<_> = (0..dim)
1598 .step_by(2)
1599 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1600 .collect();
1601 let inv_freq_len = inv_freq.len();
1602 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1603 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1604 .to_dtype(DType::F32)?
1605 .reshape((max_seq_len, 1))?;
1606 let freqs = t.matmul(&inv_freq)?;
1607 let sin = freqs.sin()?.to_dtype(dtype)?;
1608 let cos = freqs.cos()?.to_dtype(dtype)?;
1609 Ok(Self {
1610 short_cos: cos,
1611 short_sin: sin,
1612 long_cos: None,
1613 long_sin: None,
1614 original_max_position_embeddings: cfg.original_max_position_embeddings,
1615 })
1616 }
1617
1618 #[allow(clippy::too_many_arguments)]
1619 fn new_longrope(
1620 short_factor: &[f64],
1621 long_factor: &[f64],
1622 cfg: &Phi4MMConfig,
1623 dtype: DType,
1624 dev: &Device,
1625 ) -> Result<Self> {
1626 let max_seq_len = cfg.max_position_embeddings;
1627 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1628
1629 let scale =
1631 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1632 let scaling_factor = if scale <= 1.0 {
1633 1.0
1634 } else {
1635 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1636 };
1637
1638 let inv_freq_short: Vec<_> = (0..dim)
1640 .step_by(2)
1641 .enumerate()
1642 .map(|(k, i)| {
1643 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1644 })
1645 .collect();
1646 let inv_freq_len_short = inv_freq_short.len();
1647 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1648 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1649 .to_dtype(DType::F32)?
1650 .reshape((max_seq_len, 1))?;
1651 let freqs_short = t_short.matmul(&inv_freq_short)?;
1652 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1653 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1654
1655 let inv_freq_long: Vec<_> = (0..dim)
1657 .step_by(2)
1658 .enumerate()
1659 .map(|(k, i)| {
1660 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1661 })
1662 .collect();
1663 let inv_freq_len_long = inv_freq_long.len();
1664 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1665 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1666 .to_dtype(DType::F32)?
1667 .reshape((max_seq_len, 1))?;
1668 let freqs_long = t_long.matmul(&inv_freq_long)?;
1669 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1670 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1671
1672 Ok(Self {
1673 short_cos: cos_short,
1674 short_sin: sin_short,
1675 long_cos: Some(cos_long),
1676 long_sin: Some(sin_long),
1677 original_max_position_embeddings: cfg.original_max_position_embeddings,
1678 })
1679 }
1680
1681 pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1682 match &cfg.rope_scaling {
1683 Some(Phi4MMRopeScalingConfig {
1684 scaling_type: Phi4MMScaledRopeType::LongRope,
1685 short_factor: Some(short_factor),
1686 long_factor: Some(long_factor),
1687 }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1688
1689 _ => Self::new_unscaled(cfg, dtype, dev),
1690 }
1691 }
1692
1693 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1695 if self.long_cos.is_none() {
1696 return (&self.short_sin, &self.short_cos);
1697 }
1698 let seq_len = position_ids.iter().max().unwrap() + 1;
1699 if seq_len > self.original_max_position_embeddings {
1700 (
1701 self.long_sin.as_ref().unwrap(),
1702 self.long_cos.as_ref().unwrap(),
1703 )
1704 } else {
1705 (&self.short_sin, &self.short_cos)
1706 }
1707 }
1708
1709 pub fn forward(
1710 &self,
1711 q: &Tensor,
1712 k: &Tensor,
1713 seqlen_offsets: &[usize],
1714 position_ids: &[usize],
1715 ) -> Result<(Tensor, Tensor)> {
1716 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1717 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1718
1719 let rot_dim = cos.dim(D::Minus1)? * 2;
1720 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1721 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1722 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1723 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1724
1725 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1726 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1727 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1728 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1729 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1730 (q_embed, k_embed)
1731 } else {
1732 let mut q_embeds = Vec::new();
1733 let mut k_embeds = Vec::new();
1734 for (i, offset) in seqlen_offsets.iter().enumerate() {
1735 let cos = cos.narrow(0, *offset, seq_len)?;
1736 let sin = sin.narrow(0, *offset, seq_len)?;
1737 let q_embed = candle_nn::rotary_emb::rope(
1738 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1739 &cos,
1740 &sin,
1741 )?;
1742 let k_embed = candle_nn::rotary_emb::rope(
1743 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1744 &cos,
1745 &sin,
1746 )?;
1747 q_embeds.push(q_embed);
1748 k_embeds.push(k_embed);
1749 }
1750 let q_rot = Tensor::cat(&q_embeds, 0)?;
1751 let k_rot = Tensor::cat(&k_embeds, 0)?;
1752 (q_rot, k_rot)
1753 };
1754
1755 Ok((
1756 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1757 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1758 ))
1759 }
1760}
1761
1762#[derive(Debug, Clone)]
1763pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);
1764
1765#[derive(Debug, Clone, Deserialize, Serialize)]
1766#[serde(rename_all = "lowercase")]
1767pub enum Gemma3nScaledRopeType {
1768 #[serde(alias = "linear")]
1769 Linear,
1770}
1771
1772#[derive(Debug, Clone, Deserialize, Serialize)]
1773pub struct Gemma3nRopeScalingConfig {
1774 factor: f64,
1775 rope_type: Gemma3nScaledRopeType,
1776}
1777
1778impl Gemma3nRotaryEmbedding {
1779 fn new_linear(
1780 cfg: &Gemma3nTextConfig,
1781 factor: f64,
1782 is_gpt_neox: bool,
1783 dtype: DType,
1784 dev: &Device,
1785 ) -> Result<Self> {
1786 let max_seq_len = cfg.max_position_embeddings;
1787 let dim = cfg.head_dim;
1788
1789 let inv_freq: Vec<_> = (0..dim)
1790 .step_by(2)
1791 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1792 .collect();
1793 let inv_freq_len = inv_freq.len();
1794 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1795 let inv_freq = (inv_freq / factor)?;
1796
1797 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1798 .to_dtype(DType::F32)?
1799 .reshape((max_seq_len, 1))?;
1800 let freqs = t.matmul(&inv_freq)?;
1801 let sin = freqs.sin()?.to_dtype(dtype)?;
1802 let cos = freqs.cos()?.to_dtype(dtype)?;
1803 Ok(Self(RotaryEmbedding {
1804 cos,
1805 sin,
1806 is_gpt_neox,
1807 }))
1808 }
1809
1810 pub fn new(
1811 is_gpt_neox: bool,
1812 dtype: DType,
1813 cfg: &Gemma3nTextConfig,
1814 dev: &Device,
1815 ) -> Result<Self> {
1816 match &cfg.rope_scaling {
1817 Some(Gemma3RopeScalingConfig {
1818 rope_type: Gemma3ScaledRopeType::Linear,
1819 factor,
1820 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1821
1822 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1823 }
1824 }
1825
1826 pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
1827 self.0.get_cos_sin()
1828 }
1829
1830 pub fn forward(
1831 &self,
1832 q: &Tensor,
1833 k: &Tensor,
1834 seqlen_offsets: &[usize],
1835 ) -> Result<(Tensor, Tensor)> {
1836 self.0.forward(q, k, seqlen_offsets)
1837 }
1838}
1839
1840#[derive(Debug, Clone)]
1841pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1842
1843#[derive(Debug, Clone, Deserialize, Serialize)]
1844#[serde(rename_all = "lowercase")]
1845pub enum Gemma3ScaledRopeType {
1846 #[serde(alias = "linear")]
1847 Linear,
1848}
1849
1850#[derive(Debug, Clone, Deserialize, Serialize)]
1851pub struct Gemma3RopeScalingConfig {
1852 factor: f64,
1853 rope_type: Gemma3ScaledRopeType,
1854}
1855
1856impl Gemma3RotaryEmbedding {
1857 fn new_linear(
1858 cfg: &Gemma3TextConfig,
1859 factor: f64,
1860 is_gpt_neox: bool,
1861 dtype: DType,
1862 dev: &Device,
1863 ) -> Result<Self> {
1864 let max_seq_len = cfg.max_position_embeddings;
1865 let dim = cfg.head_dim;
1866
1867 let inv_freq: Vec<_> = (0..dim)
1868 .step_by(2)
1869 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1870 .collect();
1871 let inv_freq_len = inv_freq.len();
1872 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1873 let inv_freq = (inv_freq / factor)?;
1874
1875 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1876 .to_dtype(DType::F32)?
1877 .reshape((max_seq_len, 1))?;
1878 let freqs = t.matmul(&inv_freq)?;
1879 let sin = freqs.sin()?.to_dtype(dtype)?;
1880 let cos = freqs.cos()?.to_dtype(dtype)?;
1881 Ok(Self(RotaryEmbedding {
1882 cos,
1883 sin,
1884 is_gpt_neox,
1885 }))
1886 }
1887
1888 pub fn new(
1889 is_gpt_neox: bool,
1890 dtype: DType,
1891 cfg: &Gemma3TextConfig,
1892 dev: &Device,
1893 ) -> Result<Self> {
1894 match &cfg.rope_scaling {
1895 Some(Gemma3RopeScalingConfig {
1896 rope_type: Gemma3ScaledRopeType::Linear,
1897 factor,
1898 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1899
1900 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1901 }
1902 }
1903
1904 pub fn forward(
1905 &self,
1906 q: &Tensor,
1907 k: &Tensor,
1908 seqlen_offsets: &[usize],
1909 ) -> Result<(Tensor, Tensor)> {
1910 self.0.forward(q, k, seqlen_offsets)
1911 }
1912}
1913
1914pub struct DiaRotaryEmbedding {
1915 timescale: Tensor,
1916 dtype: DType,
1917}
1918
1919impl DiaRotaryEmbedding {
1920 pub fn new(
1921 min_timescale: f32,
1922 max_timescale: f32,
1923 head_dim: usize,
1924 device: &Device,
1925 dtype: DType,
1926 ) -> Result<Self> {
1927 assert_eq!(head_dim % 2, 0);
1928 let half_embedding_dim = head_dim / 2;
1929
1930 let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
1931 let timescale = fraction
1932 .into_iter()
1933 .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
1934 .collect::<Vec<_>>();
1935
1936 let timescale_len = timescale.len();
1937 let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
1938
1939 Ok(Self { timescale, dtype })
1940 }
1941
1942 pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
1943 let freqs = positions
1944 .unsqueeze(D::Minus1)?
1945 .unsqueeze(D::Minus1)?
1946 .broadcast_div(&self.timescale)?;
1947
1948 let sin = freqs.sin()?.to_dtype(self.dtype)?;
1949 let cos = freqs.cos()?.to_dtype(self.dtype)?;
1950
1951 let split = xs.chunk(2, D::Minus1)?;
1952 let first_half = &split[0];
1953 let second_half = &split[1];
1954
1955 let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
1956 let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
1957
1958 Tensor::cat(&[first_part, second_part], D::Minus1)
1959 }
1960}
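
/// Linear layer backed by a `QMatMul`. When the weight is quantized, inputs are run
/// in f32 and the output is cast back to the stored activation dtype; the optional
/// bias is kept dequantized.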

#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    bias: Option<Tensor>,
    dtype: DType,
}

impl QLinear {
    pub fn new<R: std::io::Read + std::io::Seek>(
        ct: &mut Content<'_, R>,
        name: &str,
        device: &Device,
    ) -> Result<Self> {
        let w = ct.tensor(&format!("{name}.weight"), device)?;
        let b = ct.tensor(&format!("{name}.bias"), device)?;
        let inner = QMatMul::from_qtensor(w)?;
        let bias = b.dequantize(device)?;
        Ok(Self {
            inner,
            bias: Some(bias),
            dtype: DType::F32,
        })
    }

    pub fn from_linear(linear: Linear) -> Self {
        Self {
            inner: QMatMul::Tensor(linear.weight().clone()),
            bias: linear.bias().cloned(),
            dtype: linear.weight().dtype(),
        }
    }

    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
        let dtype = w.dtype();
        Self {
            inner: QMatMul::Tensor(w),
            bias: b,
            dtype,
        }
    }

    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
        if let Some(ref b) = b {
            assert_eq!(b.dtype(), DType::F32);
        }
        Self {
            inner: QMatMul::QTensor(Arc::new(w)),
            bias: b,
            dtype: DType::F32,
        }
    }

    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
        Self {
            inner,
            bias: old.bias.clone(),
            dtype: old.dtype,
        }
    }

    pub fn inner(&mut self) -> &mut QMatMul {
        &mut self.inner
    }

    pub fn inner_ref(&self) -> &QMatMul {
        &self.inner
    }

    pub fn is_quant(&self) -> bool {
        matches!(self.inner, QMatMul::QTensor(_))
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
        self.bias.as_mut()
    }
}

impl Module for QLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = if self.is_quant() {
            xs.to_dtype(DType::F32)?
        } else {
            xs.clone()
        };
        if let Some(bias) = &self.bias {
            self.inner
                .forward(&xs)?
                .broadcast_add(bias)?
                .to_dtype(self.dtype)
        } else {
            self.inner.forward(&xs)?.to_dtype(self.dtype)
        }
    }
}
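
// Illustrative sketch (the tensor name is a placeholder): quantized GGUF weights
// take the F32 matmul path and are cast back to `self.dtype` afterwards, while
// `from_linear` keeps the activations in the original dtype throughout:
//
//     let proj = QLinear::new(&mut gguf_content, "blk.0.attn_q", &device)?;
//     let ys = proj.forward(&xs)?; // xs is upcast to F32 internally when quantized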

#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    is_gpt_neox: bool,
}

impl RotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
        Ok((self.cos.clone(), self.sin.clone()))
    }

    pub fn new_partial(
        base: f32,
        rot_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..rot_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
        let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;

        let rope = if self.is_gpt_neox {
            candle_nn::rotary_emb::rope
        } else {
            candle_nn::rotary_emb::rope_i
        };

        if cfg!(feature = "cuda") && qh == kh {
            let (cos, sin) = if seqlen_offsets.len() == 1 {
                (
                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
                )
            } else {
                let mut cos_s = Vec::new();
                let mut sin_s = Vec::new();
                for offset in seqlen_offsets {
                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
                }
                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
            };

            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
            mistralrs_quant::rotary::apply_rotary_inplace(
                &q_embed,
                &k_embed,
                &cos,
                &sin,
                self.is_gpt_neox,
            )?;
            let mut q = q_embed
                .reshape((b_sz, seq_len, qh, n_embd))?
                .transpose(1, 2)?;
            let mut k = k_embed
                .reshape((b_sz, seq_len, kh, n_embd))?
                .transpose(1, 2)?;
            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
                q = q.contiguous()?;
                k = k.contiguous()?;
            }
            Ok((q, k))
        } else if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}
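
// Illustrative sketch (shapes assumed from the callers in this crate): `q` and
// `k` are (batch, heads, seq, head_dim) and `seqlen_offsets` holds each
// sequence's current KV-cache position, so a prefill uses `&[0]` and a single
// decode step at position 41 would use `&[41]`:
//
//     let rope = RotaryEmbedding::new(10_000.0, 128, 32_768, &device, true, DType::BF16)?;
//     let (q, k) = rope.forward(&q, &k, &[0])?;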

#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
    Relu,
    Relu2,
    Relu6,
    Silu,
    Sigmoid,
    HardSigmoid,
    Swiglu,
    Swish,
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu_erf(),
            // Candle's `gelu` is the tanh approximation, matching HF's `gelu_new`.
            Self::NewGelu => xs.gelu(),
            Self::Relu => xs.relu(),
            Self::Relu2 => xs.relu()?.sqr(),
            Self::Relu6 => xs.clamp(0f32, 6f32),
            Self::Silu => xs.silu(),
            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
            Self::Swiglu => candle_nn::ops::swiglu(xs),
            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
            Self::GeluPytorchTanh => xs.gelu(),
            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
        }
    }
}
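
// Illustrative sketch: thanks to `rename_all = "lowercase"` and the aliases,
// the usual HF config strings deserialize straight into this enum (serde_json
// is assumed to be available, as elsewhere in this crate):
//
//     let act: Activation = serde_json::from_str("\"gelu_pytorch_tanh\"")?;
//     assert_eq!(act, Activation::GeluPytorchTanh);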

impl TryInto<candle_nn::Activation> for Activation {
    type Error = candle_core::Error;

    fn try_into(self) -> Result<candle_nn::Activation> {
        match self {
            Self::Gelu => Ok(candle_nn::Activation::Gelu),
            Self::Relu => Ok(candle_nn::Activation::Relu),
            Self::Silu => Ok(candle_nn::Activation::Silu),
            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
            Self::Swish => Ok(candle_nn::Activation::Swish),
            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    pub padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
}

impl Default for Conv3dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}

pub struct Conv3dNoBias {
    conv2d_1: Conv2d,
    conv2d_2: Conv2d,
}

impl Conv3dNoBias {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_sizes: [usize; 3],
        cfg: Conv3dConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let ws = vb.get(
            (
                out_channels,
                in_channels / cfg.groups,
                kernel_sizes[0],
                kernel_sizes[1],
                kernel_sizes[2],
            ),
            "weight",
        )?;

        // The 3D convolution is decomposed into two 2D convolutions, one per
        // temporal slice of the kernel. This assumes a temporal kernel size of
        // 2 (kernel_sizes[0] == 2), which is what the callers here rely on.
        let w1 = ws.i((.., .., 0, .., ..))?;
        let w2 = ws.i((.., .., 1, .., ..))?;

        let cfg = Conv2dConfig {
            padding: cfg.padding,
            stride: cfg.stride,
            dilation: cfg.dilation,
            groups: cfg.groups,
            cudnn_fwd_algo: None,
        };

        Ok(Self {
            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
        })
    }

    pub fn weight(&self) -> Result<Tensor> {
        let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?;
        let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?;
        Tensor::cat(&[w1, w2], 2)
    }
}

impl Module for Conv3dNoBias {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs1 = xs.i((.., .., 0, .., ..))?;
        let xs2 = xs.i((.., .., 1, .., ..))?;

        (Convolution.forward_2d(&self.conv2d_1, &xs1)?
            + Convolution.forward_2d(&self.conv2d_2, &xs2)?)?
            .unsqueeze(2)
    }
}

pub trait TensorInfExtend {
    fn is_inf(&self) -> Result<Self>
    where
        Self: Sized;
    fn any(&self) -> Result<bool>;
}

impl TensorInfExtend for Tensor {
    fn is_inf(&self) -> Result<Self> {
        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
    }

    fn any(&self) -> Result<bool> {
        // The mask sums to a nonzero value iff at least one element is set.
        let sum = self.sum_all()?;
        match self.dtype() {
            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
            DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
                candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with .any")
            }
        }
    }
}

pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
    let mut max = match xs.dtype() {
        DType::U8 => u8::MAX as f32 - 1000.,
        DType::U32 => u32::MAX as f32 - 1000.,
        DType::I16 => i16::MAX as f32 - 1000.,
        DType::I32 => i32::MAX as f32 - 1000.,
        DType::I64 => i64::MAX as f32 - 1000.,
        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
        DType::F32 => f32::MAX - 1000.,
        DType::F64 => f64::MAX as f32 - 1000.,
        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
        DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
            candle_core::bail!(
                "f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with `clamp_for_f16`"
            )
        }
    };
    if xs.is_inf()?.any()? {
        max -= 1000.;
    }
    xs.clamp(-max, max)
}
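
// Illustrative sketch: callers typically clamp intermediate activations back
// from the dtype's maximum so subsequent reduced-precision ops cannot overflow
// into +/-inf, e.g.
//
//     let attn_weights = clamp_for_f16(&attn_weights)?;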

pub struct FloatInfo {
    pub min: f64,
    /// Largest finite value.
    pub max: f64,
    /// Machine epsilon.
    pub eps: f64,
    /// The dtype these limits describe.
    pub dtype: DType,
}

pub trait GetFloatInfo {
    fn finfo(&self) -> Result<FloatInfo>;
}

impl GetFloatInfo for DType {
    fn finfo(&self) -> Result<FloatInfo> {
        let finfo = match self {
            Self::BF16 => FloatInfo {
                min: bf16::MIN.to_f64(),
                max: bf16::MAX.to_f64(),
                eps: bf16::EPSILON.to_f64(),
                dtype: DType::BF16,
            },
            Self::F16 => FloatInfo {
                min: f16::MIN.to_f64(),
                max: f16::MAX.to_f64(),
                eps: f16::EPSILON.to_f64(),
                dtype: DType::F16,
            },
            Self::F32 => FloatInfo {
                min: f32::MIN as f64,
                max: f32::MAX as f64,
                eps: f32::EPSILON as f64,
                dtype: DType::F32,
            },
            Self::F64 => FloatInfo {
                min: f64::MIN,
                max: f64::MAX,
                eps: f64::EPSILON,
                dtype: DType::F64,
            },
            Self::F8E4M3 => FloatInfo {
                min: F8E4M3::MIN.to_f64(),
                max: F8E4M3::MAX.to_f64(),
                eps: F8E4M3::EPSILON.to_f64(),
                dtype: DType::F8E4M3,
            },
            other => {
                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
            }
        };
        Ok(finfo)
    }
}
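
// Illustrative sketch: a common consumer is attention masking, where the
// dtype's most negative finite value stands in for -inf:
//
//     let neg_inf = xs.dtype().finfo()?.min;
//     // fill masked-out positions with `neg_inf` instead of f64::NEG_INFINITY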

#[derive(Clone)]
pub struct Mlp {
    pub gate: Arc<dyn QuantMethod>,
    pub up: Arc<dyn QuantMethod>,
    pub down: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl Mlp {
    pub fn new(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn new_merged(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        chunks: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
        let gate_up_projs = ColumnParallelLayer::new_merged(
            hidden_size,
            intermediate_size * 2,
            2,
            quantization_config,
            false,
            comm,
            vb.pp("gate_up_proj"),
        )?;

        Ok(Self {
            gate: gate_up_projs[0].to_owned(),
            up: gate_up_projs[1].to_owned(),
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn replicate(
        params: &[usize],
        vb: ShardedVarBuilder,
        act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Self::new(vb, params[0], params[1], &None, act, comm)
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        let mut res = self
            .down
            .forward(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
        let mut res =
            MatMul.qmethod_matmul(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?, &*self.down)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate, &mut self.up, &mut self.down]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let gate = if let Some(ref delta) = deltas[0] {
            self.gate.add_delta_w(delta)?
        } else {
            self.gate.clone()
        };
        let up = if let Some(ref delta) = deltas[1] {
            self.up.add_delta_w(delta)?
        } else {
            self.up.clone()
        };
        let down = if let Some(ref delta) = deltas[2] {
            self.down.add_delta_w(delta)?
        } else {
            self.down.clone()
        };

        Ok(Box::new(Self {
            gate,
            up,
            down,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate.dtype_and_device()
    }
}
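
// Both forward paths compute the usual gated MLP,
//     down( act(gate(x)) * up(x) )
// the `MlpLayer` implementation simply routes the matmuls through the shared
// `MatMul` helper instead of calling the layers' `forward` directly.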

pub struct AvgPool2d {
    kernel_size: usize,
    stride: usize,
}

impl AvgPool2d {
    pub fn new(kernel_size: usize, stride: usize) -> Self {
        Self {
            kernel_size,
            stride,
        }
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
    }
}
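
// Worked example of the output size (no padding, as above): for a
// (1, c, 24, 24) input with kernel_size = 2 and stride = 2, each spatial dim
// becomes floor((24 - 2) / 2) + 1 = 12, i.e. the output is (1, c, 12, 12).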

/// Reflection padding for 4D (N, C, H, W) tensors, mirroring the interior of
/// the tensor across each edge (like `torch.nn.ReflectionPad2d`).
pub struct ReflectionPad2d {
    padding: (usize, usize, usize, usize),
}

impl ReflectionPad2d {
    /// `padding` is (left, right, top, bottom).
    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
        Self { padding }
    }
}

impl Module for ReflectionPad2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;

        let (_n, _c, h, w) = xs.dims4()?;

        // Left padding mirrors columns 1..=pad_left (column 0 is the reflection axis).
        let left_pad = if pad_left > 0 {
            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
        } else {
            None
        };

        // Right padding mirrors columns w-2, w-3, ... (column w-1 is the axis).
        let right_pad = if pad_right > 0 {
            let start = w as i64 - 2;
            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
        } else {
            None
        };

        // Concatenate along the width dimension.
        let x_padded_width = match (left_pad, right_pad) {
            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
            (None, None) => xs.clone(),
        };

        // Top padding mirrors rows 1..=pad_top of the width-padded tensor.
        let top_pad = if pad_top > 0 {
            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
        } else {
            None
        };

        // Bottom padding mirrors rows h-2, h-3, ...
        let bottom_pad = if pad_bottom > 0 {
            let start = h as i64 - 2;
            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
        } else {
            None
        };

        // Concatenate along the height dimension.
        let x_padded = match (top_pad, bottom_pad) {
            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
            (None, None) => x_padded_width,
        };

        Ok(x_padded)
    }
}
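
// Illustrative sketch: padding a 4x4 image by one pixel on every side with
// `ReflectionPad2d::new((1, 1, 1, 1))` yields a 6x6 output whose border rows
// and columns mirror the rows/columns just inside the original border (the
// edge itself is never duplicated).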

pub struct ScaledEmbedding {
    scale: f64,
    pub embedding: Tensor,
}

impl ScaledEmbedding {
    pub fn new(scale: f64, embedding: Embedding) -> Self {
        Self {
            scale,
            embedding: embedding.embeddings().clone(),
        }
    }

    pub fn embeddings(&self) -> &Tensor {
        &self.embedding
    }
}

impl Module for ScaledEmbedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
        xs.apply(&embedding)? * self.scale
    }
}
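
// Illustrative sketch: Gemma-style models scale token embeddings by
// sqrt(hidden_size), which is the kind of wrapper this type provides, e.g.
//
//     let embed = ScaledEmbedding::new((hidden_size as f64).sqrt(), embedding);
//     let hidden = embed.forward(&token_ids)?;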