#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};

use candle_core::{
    quantized::{QMatMul, QTensor},
    Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{
    BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
    LayerNorm, LayerNormConfig, Linear, Module,
};
use float8::F8E4M3;
use half::{bf16, f16};
use mistralrs_quant::{
    AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};

pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::repeat_kv;
use crate::{
    amoe::{AnyMoeTrainableLayer, MlpLayer},
    gguf::Content,
    models::{llama, smollm3},
    ops::SplitOp,
    vision_models::{
        gemma3::config::Gemma3TextConfig,
        gemma3n::config::Gemma3nTextConfig,
        llama4,
        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
        phi4::Phi4MMConfig,
    },
};

pub use mistralrs_quant::MatMul;

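/// Load an [`Embedding`] layer from `vb`. When the model carries an AFQ quantized config, the
/// quantized weight is loaded through [`AfqLayer`] and dequantized; otherwise the dense
/// `weight` tensor is read directly.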
pub fn embedding(
    in_size: usize,
    out_size: usize,
    vb: ShardedVarBuilder,
    config: &Option<QuantizedConfig>,
) -> Result<Embedding> {
    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
        let afq_layer =
            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
        afq_layer.dequantize_w()?
    } else {
        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
    };
    Ok(Embedding::new(embeddings, out_size))
}

pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get(size, "weight")?;
    if config.affine {
        let bias = vb.get(size, "bias")?;
        Ok(LayerNorm::new(weight, bias, config.eps))
    } else {
        Ok(LayerNorm::new_no_bias(weight, config.eps))
    }
}

pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<BatchNorm> {
    let config = config.into();
    if config.eps < 0. {
        candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    let running_mean = vb.get(num_features, "running_mean")?;
    let running_var = vb.get(num_features, "running_var")?;

    if config.affine {
        let weight = vb.get(num_features, "weight")?;
        let bias = vb.get(num_features, "bias")?;
        BatchNorm::new(
            num_features,
            running_mean,
            running_var,
            weight,
            bias,
            config.eps,
        )
    } else {
        BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
    }
}

pub fn group_norm(
    num_groups: usize,
    num_channels: usize,
    eps: f64,
    vb: ShardedVarBuilder,
) -> Result<GroupNorm> {
    let weight = vb.get(num_channels, "weight")?;
    let bias = vb.get(num_channels, "bias")?;
    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
}

pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}

pub fn conv2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}

pub fn conv1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

pub fn conv1d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    Ok(Conv1d::new(ws, None, cfg))
}

pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    let bs = vb.get(out_dim, "bias")?;
    Ok(Linear::new(ws, Some(bs)))
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    Ok(Linear::new(ws, None))
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: ShardedVarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}

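/// RMS normalization with a learned scale.
///
/// `new_gemma` stores `weight + 1` so the Gemma-style `(1 + w) * rms_norm(x)` can be applied with
/// a single multiply; `undo_gemma` recovers the original weight.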
#[derive(Debug, Clone)]
pub struct RmsNorm {
    eps: f64,
    weight: Tensor,
}

impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        Ok(Self { eps, weight: w })
    }

    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        let w = (w + 1.0)?;
        Ok(Self { eps, weight: w })
    }

    pub fn new_gemma_3n(
        size: usize,
        eps: f64,
        with_scale: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let w = if with_scale {
            vb.get(size, "weight")?
        } else {
            Tensor::ones(size, vb.dtype(), vb.device())?
        };
        Ok(Self { eps, weight: w })
    }

    pub fn undo_gemma(&self) -> Result<Self> {
        Ok(Self {
            eps: self.eps,
            weight: (&self.weight - 1.0)?,
        })
    }

    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
        Ok(Self { eps, weight: w })
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    w: Tensor,
    eps: f64,
}

impl F32RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }

    pub fn weight(&self) -> &Tensor {
        &self.w
    }
}

impl Module for F32RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}

#[derive(Debug, Clone)]
pub struct QRmsNorm {
    eps: f64,
    weight: Tensor,
}

impl QRmsNorm {
    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
        let scale = scale.dequantize(&scale.device())?;
        Ok(Self {
            eps: eps as f64,
            weight: scale,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

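/// Phi-3 style rotary embedding with optional long/short scaling (`su`/`longrope` or `yarn`):
/// separate sin/cos tables are built for the short and long regimes and selected at runtime based
/// on the largest position id seen.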
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    #[serde(alias = "yarn")]
    Yarn,
    #[serde(alias = "dynamic")]
    Dynamic,
    #[serde(alias = "linear")]
    Linear,
}

impl FromStr for ScaledRopeType {
    type Err = candle_core::Error;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "su" | "longrope" => Ok(Self::Su),
            "yarn" => Ok(Self::Yarn),
            "linear" => Ok(Self::Linear),
            "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Expected `su`/`longrope`, `yarn`, `linear`, or `dynamic` scaled RoPE type.".to_string(),
            )),
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
    Classic {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    Scaled {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
    },
}

pub struct PhiRopeConfig {
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub original_max_position_embeddings: usize,
    pub rope_theta: f64,
    pub head_dim: usize,
    pub partial_rotary_factor: Option<f64>,
}

impl PhiRotaryEmbedding {
    fn new_classic_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            match scaling_type {
                ScaledRopeType::Su => {
                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
                }
                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
            }
        };

        let inv_freq_long = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_short = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq_long.len();

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;

        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
        let freqs_long = t.matmul(&inv_freq_long)?;
        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        let inv_freq_short =
            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
        let freqs_short = t.matmul(&inv_freq_short)?;
        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        Ok(Self {
            short_cos,
            short_sin,
            long_cos: Some(long_cos),
            long_sin: Some(long_sin),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        if !matches!(scaling_type, ScaledRopeType::Su) {
            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
        }

        if short_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
                short_factor.len(),
                dim / 2
            );
        }
        if long_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
                long_factor.len(),
                dim / 2
            );
        }

        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;

        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
        let cfg: PhiRopeConfig = cfg.into();

        match &cfg.rope_scaling {
            Some(PhiRopeScalingConfig::Classic {
                short_factor,
                long_factor,
                scaling_type,
            }) => {
                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
            }

            Some(PhiRopeScalingConfig::Scaled {
                short_factor,
                long_factor,
                scaling_type,
                long_mscale,
                short_mscale,
            }) => Self::new_scaled(
                short_factor,
                long_factor,
                scaling_type,
                *long_mscale,
                *short_mscale,
                &cfg,
                dtype,
                dev,
            ),

            None => Self::new_unscaled(&cfg, dtype, dev),
        }
    }

    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        let rot_dim = cos.dim(D::Minus1)? * 2;

        if rot_dim != q.dim(D::Minus1)? {
            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
                (q_embed, k_embed)
            } else {
                let mut q_embeds = Vec::new();
                let mut k_embeds = Vec::new();
                for (i, offset) in seqlen_offsets.iter().enumerate() {
                    let cos = cos.narrow(0, *offset, seq_len)?;
                    let sin = sin.narrow(0, *offset, seq_len)?;
                    let q_embed = candle_nn::rotary_emb::rope(
                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    let k_embed = candle_nn::rotary_emb::rope(
                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    q_embeds.push(q_embed);
                    k_embeds.push(k_embed);
                }
                let q_rot = Tensor::cat(&q_embeds, 0)?;
                let k_rot = Tensor::cat(&k_embeds, 0)?;
                (q_rot, k_rot)
            };

            Ok((
                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
            ))
        } else if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

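/// Rotary embedding for Llama-family models, supporting the `llama3`, `linear`, and `default`
/// rope-scaling variants on top of the plain [`RotaryEmbedding`].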
#[derive(Debug, Clone)]
pub struct Llama3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[serde(rename = "linear")]
    Linear,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: Option<f32>,
    pub high_freq_factor: Option<f32>,
    pub original_max_position_embeddings: Option<usize>,
    pub rope_type: Llama3RopeType,
}

fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

impl Llama3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &llama::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_llama4(
        dtype: DType,
        cfg: &llama4::TextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_mllama3(
        dtype: DType,
        cfg: &MLlamaTextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Llama3,
                original_max_position_embeddings,
                factor,
                attention_factor: _,
                beta_fast: _,
                beta_slow: _,
                short_factor: _,
                long_factor: _,
                low_freq_factor,
                high_freq_factor,
            }) => {
                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
                let low_freq_factor = low_freq_factor
                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
                let high_freq_factor = high_freq_factor
                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;

                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;

                let inv_freq = (0..head_dim)
                    .step_by(2)
                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / factor
                        } else {
                            let smooth = (*original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(MLlamaRopeScaling {
                rope_type: other, ..
            }) => {
                candle_core::bail!(
                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
                )
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

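/// Rotary embedding for SmolLM3, mirroring the Llama 3 scaling logic (`llama3`, `linear`, or
/// `default` rope-scaling variants).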
#[derive(Debug, Clone)]
pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum SmolLm3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[serde(rename = "linear")]
    Linear,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SmolLm3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: Option<f32>,
    pub high_freq_factor: Option<f32>,
    pub original_max_position_embeddings: Option<usize>,
    pub rope_type: SmolLm3RopeType,
}

fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

impl SmolLm3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &smollm3::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq_smollm3(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

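/// Multimodal rotary embedding (mrope) for Qwen2-VL: cos/sin are computed from three position-id
/// streams and recombined according to `mrope_section` before being applied to q/k.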
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct Qwen2_5VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2_5VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

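/// Rotary embedding for DeepSeek-V2 style attention (applied over `qk_rope_head_dim`), with
/// optional YaRN scaling; linear and dynamic scaling are not implemented.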
#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}

pub struct DeepSeekV2RopeConfig {
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub max_position_embeddings: usize,
    pub rope_theta: f32,
    pub qk_rope_head_dim: usize,
}

impl DeepSeekV2RotaryEmbedding {
    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.qk_rope_head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    fn yarn_find_correction_dim(
        num_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> f32 {
        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
            / (2. * base.ln())
    }

    fn yarn_find_correction_range(
        low_rot: f32,
        high_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> (f32, f32) {
        let low =
            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
        let high =
            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
        (low.max(0.), high.min(dim as f32 - 1.))
    }

    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
        if min == max {
            max += 0.001;
        }
        let linear_func =
            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
        linear_func.clamp(0., 1.)
    }

    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
        if scale <= 1. {
            return 1.;
        }
        0.1 * mscale * scale.ln() + 1.
    }

    #[allow(clippy::too_many_arguments)]
    fn new_yarn(
        cfg: &DeepSeekV2RopeConfig,
        dtype: DType,
        dev: &Device,
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        factor: f32,
        mscale: f32,
        mscale_all_dim: f32,
    ) -> Result<Self> {
        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
            .collect();
        let freq_extra_len = freq_extra.len();
        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
            .collect();
        let freq_inter_len = freq_inter.len();
        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;

        let (low, high) = Self::yarn_find_correction_range(
            beta_fast,
            beta_slow,
            cfg.qk_rope_head_dim,
            cfg.rope_theta,
            original_max_position_embeddings,
        );
        let inv_freq_mask =
            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
        let inv_freq = freq_inter
            .broadcast_mul(&(1. - &inv_freq_mask)?)?
            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;

        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((cfg.max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let mscale =
            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
                scaling_type: _,
                factor: _,
            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
            Some(DeepSeekV2RopeScaling::Yarn {
                original_max_position_embeddings,
                beta_fast,
                beta_slow,
                factor,
                mscale,
                mscale_all_dim,
                scaling_type: _,
            }) => Self::new_yarn(
                cfg,
                dtype,
                dev,
                *original_max_position_embeddings,
                *beta_fast,
                *beta_slow,
                *factor,
                *mscale,
                *mscale_all_dim,
            ),
            None => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope_i(
                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope_i(
                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

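/// Rotary embedding for Phi-4 multimodal, supporting the `longrope` scaling variant with separate
/// short/long factor tables selected by the largest position id.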
#[derive(Debug, Clone)]
pub struct Phi4MMRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Phi4MMScaledRopeType {
    #[serde(alias = "longrope")]
    LongRope,
    #[default]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Phi4MMRopeScalingConfig {
    short_factor: Option<Vec<f64>>,
    long_factor: Option<Vec<f64>>,
    #[serde(rename = "type")]
    scaling_type: Phi4MMScaledRopeType,
}

impl Phi4MMRotaryEmbedding {
    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_longrope(
        short_factor: &[f64],
        long_factor: &[f64],
        cfg: &Phi4MMConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
        };

        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;

        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;

        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Phi4MMRopeScalingConfig {
                scaling_type: Phi4MMScaledRopeType::LongRope,
                short_factor: Some(short_factor),
                long_factor: Some(long_factor),
            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),

            _ => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);

        let rot_dim = cos.dim(D::Minus1)? * 2;
        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
            (q_embed, k_embed)
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(
                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope(
                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            let q_rot = Tensor::cat(&q_embeds, 0)?;
            let k_rot = Tensor::cat(&k_embeds, 0)?;
            (q_rot, k_rot)
        };

        Ok((
            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
        ))
    }
}

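/// Rotary embedding for Gemma 3n text models; only linear rope scaling (or none) is supported.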
#[derive(Debug, Clone)]
pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3nScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3nRopeScalingConfig {
    factor: f64,
    rope_type: Gemma3nScaledRopeType,
}

impl Gemma3nRotaryEmbedding {
    fn new_linear(
        cfg: &Gemma3nTextConfig,
        factor: f64,
        is_gpt_neox: bool,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let inv_freq = (inv_freq / factor)?;

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self(RotaryEmbedding {
            cos,
            sin,
            is_gpt_neox,
        }))
    }

    pub fn new(
        is_gpt_neox: bool,
        dtype: DType,
        cfg: &Gemma3nTextConfig,
        dev: &Device,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Gemma3RopeScalingConfig {
                rope_type: Gemma3ScaledRopeType::Linear,
                factor,
            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),

            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
        }
    }

    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
        self.0.get_cos_sin()
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

#[derive(Debug, Clone)]
pub struct Gemma3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3ScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3RopeScalingConfig {
    factor: f64,
    rope_type: Gemma3ScaledRopeType,
}

impl Gemma3RotaryEmbedding {
    fn new_linear(
        cfg: &Gemma3TextConfig,
        factor: f64,
        is_gpt_neox: bool,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let inv_freq = (inv_freq / factor)?;

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self(RotaryEmbedding {
            cos,
            sin,
            is_gpt_neox,
        }))
    }

    pub fn new(
        is_gpt_neox: bool,
        dtype: DType,
        cfg: &Gemma3TextConfig,
        dev: &Device,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Gemma3RopeScalingConfig {
                rope_type: Gemma3ScaledRopeType::Linear,
                factor,
            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),

            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

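/// Rotary embedding used by the Dia model: positions are divided by a geometric range of
/// timescales between `min_timescale` and `max_timescale`, then applied as a pairwise rotation of
/// the two halves of the head dimension.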
pub struct DiaRotaryEmbedding {
    timescale: Tensor,
    dtype: DType,
}

impl DiaRotaryEmbedding {
    pub fn new(
        min_timescale: f32,
        max_timescale: f32,
        head_dim: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        assert_eq!(head_dim % 2, 0);
        let half_embedding_dim = head_dim / 2;

        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
        let timescale = fraction
            .into_iter()
            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
            .collect::<Vec<_>>();

        let timescale_len = timescale.len();
        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;

        Ok(Self { timescale, dtype })
    }

    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
        let freqs = positions
            .unsqueeze(D::Minus1)?
            .unsqueeze(D::Minus1)?
            .broadcast_div(&self.timescale)?;

        let sin = freqs.sin()?.to_dtype(self.dtype)?;
        let cos = freqs.cos()?.to_dtype(self.dtype)?;

        let split = xs.chunk(2, D::Minus1)?;
        let first_half = &split[0];
        let second_half = &split[1];

        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;

        Tensor::cat(&[first_part, second_part], D::Minus1)
    }
}
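/// Linear layer backed by a [`QMatMul`], keeping an optional dequantized bias and remembering the
/// activation dtype so quantized matmuls can run in f32 and be cast back afterwards.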
#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    bias: Option<Tensor>,
    dtype: DType,
}

impl QLinear {
    pub fn new<R: std::io::Read + std::io::Seek>(
        ct: &mut Content<'_, R>,
        name: &str,
        device: &Device,
    ) -> Result<Self> {
        let w = ct.tensor(&format!("{name}.weight"), device)?;
        let b = ct.tensor(&format!("{name}.bias"), device)?;
        let inner = QMatMul::from_qtensor(w)?;
        let bias = b.dequantize(device)?;
        Ok(Self {
            inner,
            bias: Some(bias),
            dtype: DType::F32,
        })
    }

    pub fn from_linear(linear: Linear) -> Self {
        Self {
            inner: QMatMul::Tensor(linear.weight().clone()),
            bias: linear.bias().cloned(),
            dtype: linear.weight().dtype(),
        }
    }

    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
        let dtype = w.dtype();
        Self {
            inner: QMatMul::Tensor(w),
            bias: b,
            dtype,
        }
    }

    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
        if let Some(ref b) = b {
            assert_eq!(b.dtype(), DType::F32);
        }
        Self {
            inner: QMatMul::QTensor(Arc::new(w)),
            bias: b,
            dtype: DType::F32,
        }
    }

    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
        Self {
            inner,
            bias: old.bias.clone(),
            dtype: old.dtype,
        }
    }

    pub fn inner(&mut self) -> &mut QMatMul {
        &mut self.inner
    }

    pub fn inner_ref(&self) -> &QMatMul {
        &self.inner
    }

    pub fn is_quant(&self) -> bool {
        matches!(self.inner, QMatMul::QTensor(_))
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
        self.bias.as_mut()
    }
}

impl Module for QLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = if self.is_quant() {
            xs.to_dtype(DType::F32)?
        } else {
            xs.clone()
        };
        if let Some(bias) = &self.bias {
            self.inner
                .forward(&xs)?
                .broadcast_add(bias)?
                .to_dtype(self.dtype)
        } else {
            self.inner.forward(&xs)?.to_dtype(self.dtype)
        }
    }
}

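/// Standard rotary position embedding over precomputed cos/sin tables. `is_gpt_neox` selects the
/// non-interleaved (`rope`) vs interleaved (`rope_i`) layout; with the `cuda` feature a fused
/// in-place kernel is used when query and key head counts match.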
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    is_gpt_neox: bool,
}

impl RotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
        Ok((self.cos.clone(), self.sin.clone()))
    }

    pub fn new_partial(
        base: f32,
        rot_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..rot_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;

        let rope = if self.is_gpt_neox {
            candle_nn::rotary_emb::rope
        } else {
            candle_nn::rotary_emb::rope_i
        };

        if cfg!(feature = "cuda") && qh == kh {
            let (cos, sin) = if seqlen_offsets.len() == 1 {
                (
                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
                )
            } else {
                let mut cos_s = Vec::new();
                let mut sin_s = Vec::new();
                for offset in seqlen_offsets {
                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
                }
                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
            };

            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
            mistralrs_quant::rotary::apply_rotary_inplace(
                &q_embed,
                &k_embed,
                &cos,
                &sin,
                self.is_gpt_neox,
            )?;
            let mut q = q_embed
                .reshape((b_sz, seq_len, qh, n_embd))?
                .transpose(1, 2)?;
            let mut k = k_embed
                .reshape((b_sz, seq_len, kh, n_embd))?
                .transpose(1, 2)?;
            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
                q = q.contiguous()?;
                k = k.contiguous()?;
            }
            Ok((q, k))
        } else if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

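/// Activation functions, deserializable from the names used in Hugging Face
/// model configs (e.g. `"gelu"`, `"gelu_new"`, `"gelu_pytorch_tanh"`, `"silu"`).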
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
    Relu,
    Relu2,
    Relu6,
    Silu,
    Sigmoid,
    HardSigmoid,
    Swiglu,
    Swish,
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu_erf(),
            Self::NewGelu => xs.gelu(),
            Self::Relu => xs.relu(),
            Self::Relu2 => xs.relu()?.sqr(),
            Self::Relu6 => xs.clamp(0f32, 6f32),
            Self::Silu => xs.silu(),
            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
            Self::Swiglu => candle_nn::ops::swiglu(xs),
            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
            Self::GeluPytorchTanh => xs.gelu(),
            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?)?,
        }
    }
}

impl TryInto<candle_nn::Activation> for Activation {
    type Error = candle_core::Error;

    fn try_into(self) -> Result<candle_nn::Activation> {
        match self {
            Self::Gelu => Ok(candle_nn::Activation::Gelu),
            Self::Relu => Ok(candle_nn::Activation::Relu),
            Self::Silu => Ok(candle_nn::Activation::Silu),
            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
            Self::Swish => Ok(candle_nn::Activation::Swish),
            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
        }
    }
}

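/// Configuration for [`Conv3dNoBias`], mirroring `candle_nn::Conv2dConfig`
/// (padding, stride, dilation, groups) with the same defaults.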
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    pub padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
}

impl Default for Conv3dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}

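/// A bias-free 3D convolution expressed as two 2D convolutions.
///
/// The 5D kernel `(out, in / groups, kt, kh, kw)` is split along its temporal
/// axis into the slices at `t = 0` and `t = 1`, and `forward` sums the two 2D
/// convolutions over the corresponding input frames. As written, this assumes
/// a temporal kernel size (and input depth) of 2.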
pub struct Conv3dNoBias {
    conv2d_1: Conv2d,
    conv2d_2: Conv2d,
}

impl Conv3dNoBias {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_sizes: [usize; 3],
        cfg: Conv3dConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let ws = vb.get(
            (
                out_channels,
                in_channels / cfg.groups,
                kernel_sizes[0],
                kernel_sizes[1],
                kernel_sizes[2],
            ),
            "weight",
        )?;

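        // Take the two temporal slices of the 3D kernel; each becomes a 2D kernel.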
        let w1 = ws.i((.., .., 0, .., ..))?;
        let w2 = ws.i((.., .., 1, .., ..))?;

        let cfg = Conv2dConfig {
            padding: cfg.padding,
            stride: cfg.stride,
            dilation: cfg.dilation,
            groups: cfg.groups,
        };

        Ok(Self {
            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
        })
    }
}

impl Module for Conv3dNoBias {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs1 = xs.i((.., .., 0, .., ..))?;
        let xs2 = xs.i((.., .., 1, .., ..))?;

        (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
    }
}

pub trait TensorInfExtend {
    fn is_inf(&self) -> Result<Self>
    where
        Self: Sized;
    fn any(&self) -> Result<bool>;
}

impl TensorInfExtend for Tensor {
    fn is_inf(&self) -> Result<Self> {
        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
    }

    fn any(&self) -> Result<bool> {
        let sum = self.sum_all()?;
        match self.dtype() {
            DType::U8 => Ok(sum.to_scalar::<u8>()? == 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? == 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? == 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? == 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? == 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? == half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? == half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? == 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? == 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? == F8E4M3::ZERO),
        }
    }
}

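/// Clamps `xs` to just inside the maximum finite value of its dtype, with a
/// headroom of 1000 (and an extra 1000 if the tensor already contains
/// infinities).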
pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
    let mut max = match xs.dtype() {
        DType::U8 => u8::MAX as f32 - 1000.,
        DType::U32 => u32::MAX as f32 - 1000.,
        DType::I16 => i16::MAX as f32 - 1000.,
        DType::I32 => i32::MAX as f32 - 1000.,
        DType::I64 => i64::MAX as f32 - 1000.,
        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
        DType::F32 => f32::MAX - 1000.,
        DType::F64 => f64::MAX as f32 - 1000.,
        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
    };
    if xs.is_inf()?.any()? {
        max -= 1000.;
    }
    xs.clamp(-max, max)
}

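/// Minimum, maximum, and machine-epsilon information for a float [`DType`],
/// along the lines of `torch.finfo`.
///
/// A minimal sketch:
///
/// ```ignore
/// let info = DType::BF16.finfo()?;
/// assert_eq!(info.dtype, DType::BF16);
/// ```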
pub struct FloatInfo {
    pub min: f64,
    pub max: f64,
    pub eps: f64,
    pub dtype: DType,
}

pub trait GetFloatInfo {
    fn finfo(&self) -> Result<FloatInfo>;
}

impl GetFloatInfo for DType {
    fn finfo(&self) -> Result<FloatInfo> {
        let finfo = match self {
            Self::BF16 => FloatInfo {
                min: bf16::MIN.to_f64(),
                max: bf16::MAX.to_f64(),
                eps: bf16::EPSILON.to_f64(),
                dtype: DType::BF16,
            },
            Self::F16 => FloatInfo {
                min: f16::MIN.to_f64(),
                max: f16::MAX.to_f64(),
                eps: f16::EPSILON.to_f64(),
                dtype: DType::F16,
            },
            Self::F32 => FloatInfo {
                min: f32::MIN as f64,
                max: f32::MAX as f64,
                eps: f32::EPSILON as f64,
                dtype: DType::F32,
            },
            Self::F64 => FloatInfo {
                min: f64::MIN,
                max: f64::MAX,
                eps: f64::EPSILON,
                dtype: DType::F64,
            },
            Self::F8E4M3 => FloatInfo {
                min: F8E4M3::MIN.to_f64(),
                max: F8E4M3::MAX.to_f64(),
                eps: F8E4M3::EPSILON.to_f64(),
                dtype: DType::F8E4M3,
            },
            other => {
                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
            }
        };
        Ok(finfo)
    }
}

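/// Gated MLP block: `down(act(gate(x)) * up(x))`, with the gate/up projections
/// column-parallel and the down projection row-parallel across `comm`.
///
/// A minimal construction sketch (hypothetical sizes and var-builder):
///
/// ```ignore
/// let mlp = Mlp::new(vb.pp("mlp"), 4096, 11008, &None, Activation::Silu, &comm)?;
/// let ys = mlp.forward(&xs)?;
/// ```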
#[derive(Clone)]
pub struct Mlp {
    pub gate: Arc<dyn QuantMethod>,
    pub up: Arc<dyn QuantMethod>,
    pub down: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl Mlp {
    pub fn new(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

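    /// Like [`Mlp::new`], but loads a fused `gate_up_proj` weight and splits it
    /// into the gate and up projections. Only `chunks == 2` is supported.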
    pub fn new_merged(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        chunks: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
        let gate_up_projs = ColumnParallelLayer::new_merged(
            hidden_size,
            intermediate_size * 2,
            2,
            quantization_config,
            false,
            comm,
            vb.pp("gate_up_proj"),
        )?;

        Ok(Self {
            gate: gate_up_projs[0].to_owned(),
            up: gate_up_projs[1].to_owned(),
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn replicate(
        params: &[usize],
        vb: ShardedVarBuilder,
        act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Self::new(vb, params[0], params[1], &None, act, comm)
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
            &lhs,
            &rhs,
            self.act.try_into()?,
        )?)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
        let mut res = if matches!(
            self.act,
            Activation::Gelu | Activation::Silu | Activation::Relu
        ) {
            MatMul.qmethod_matmul(
                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
                &*self.down,
            )?
        } else {
            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
        };
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate, &mut self.up, &mut self.down]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let gate = if let Some(ref delta) = deltas[0] {
            self.gate.add_delta_w(delta)?
        } else {
            self.gate.clone()
        };
        let up = if let Some(ref delta) = deltas[1] {
            self.up.add_delta_w(delta)?
        } else {
            self.up.clone()
        };
        let down = if let Some(ref delta) = deltas[2] {
            self.down.add_delta_w(delta)?
        } else {
            self.down.clone()
        };

        Ok(Box::new(Self {
            gate,
            up,
            down,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate.dtype_and_device()
    }
}

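/// 2D average pooling with an explicit stride; a thin wrapper around
/// `Tensor::avg_pool2d_with_stride`.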
pub struct AvgPool2d {
    kernel_size: usize,
    stride: usize,
}

impl AvgPool2d {
    pub fn new(kernel_size: usize, stride: usize) -> Self {
        Self {
            kernel_size,
            stride,
        }
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
    }
}

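/// Reflection padding over the last two dimensions of an `(N, C, H, W)` tensor,
/// in the spirit of `torch.nn.ReflectionPad2d`.
///
/// `padding` is `(left, right, top, bottom)`. The reflection excludes the edge
/// element itself, so each pad amount must be smaller than the corresponding
/// input dimension.
///
/// A minimal sketch (hypothetical shapes):
///
/// ```ignore
/// let pad = ReflectionPad2d::new((1, 1, 2, 2));
/// let ys = pad.forward(&xs)?; // (N, C, H + 4, W + 2)
/// ```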
pub struct ReflectionPad2d {
    padding: (usize, usize, usize, usize),
}

impl ReflectionPad2d {
    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
        Self { padding }
    }
}

impl Module for ReflectionPad2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;

        let (_n, _c, h, w) = xs.dims4()?;

        let left_pad = if pad_left > 0 {
            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
        } else {
            None
        };

        let right_pad = if pad_right > 0 {
            let start = w as i64 - 2;
            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
        } else {
            None
        };

        let x_padded_width = match (left_pad, right_pad) {
            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
            (None, None) => xs.clone(),
        };

        let top_pad = if pad_top > 0 {
            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
        } else {
            None
        };

        let bottom_pad = if pad_bottom > 0 {
            let start = h as i64 - 2;
            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
        } else {
            None
        };

        let x_padded = match (top_pad, bottom_pad) {
            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
            (None, None) => x_padded_width,
        };

        Ok(x_padded)
    }
}

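/// Embedding lookup whose output is multiplied by a constant `scale`
/// (as used, for example, by models that scale token embeddings by
/// `sqrt(hidden_size)`).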
pub struct ScaledEmbedding {
    scale: f64,
    pub embedding: Tensor,
}

impl ScaledEmbedding {
    pub fn new(scale: f64, embedding: Embedding) -> Self {
        Self {
            scale,
            embedding: embedding.embeddings().clone(),
        }
    }

    pub fn embeddings(&self) -> &Tensor {
        &self.embedding
    }
}

impl Module for ScaledEmbedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
        xs.apply(&embedding)? * self.scale
    }
}