mistralrs_core/
layers.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6    quantized::{QMatMul, QTensor},
7    Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10    BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
11    LayerNorm, LayerNormConfig, Linear, Module,
12};
13use float8::F8E4M3;
14use half::{bf16, f16};
15use mistralrs_quant::{
16    AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
17    ShardedVarBuilder,
18};
19use serde::{Deserialize, Serialize};
20
21pub use crate::attention::Sdpa;
22pub use crate::layers_masker::CausalMasker;
23pub use crate::layers_utils::repeat_kv;
24use crate::{
25    amoe::{AnyMoeTrainableLayer, MlpLayer},
26    gguf::Content,
27    models::{llama, smollm3},
28    ops::SplitOp,
29    vision_models::{
30        gemma3::config::Gemma3TextConfig,
31        gemma3n::config::Gemma3nTextConfig,
32        llama4,
33        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
34        phi4::Phi4MMConfig,
35    },
36};
37
38pub use mistralrs_quant::MatMul;
39
40pub fn embedding(
41    in_size: usize,
42    out_size: usize,
43    vb: ShardedVarBuilder,
44    config: &Option<QuantizedConfig>,
45) -> Result<Embedding> {
46    // AFQ quantized applies quantization to the embeddings.
47    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
48        let afq_layer =
49            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
50        afq_layer.dequantize_w()?
51    } else {
52        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
53    };
54    Ok(Embedding::new(embeddings, out_size))
55}
56
57pub fn layer_norm<C: Into<LayerNormConfig>>(
58    size: usize,
59    config: C,
60    vb: ShardedVarBuilder,
61) -> Result<LayerNorm> {
62    let config = config.into();
63    let weight = vb.get(size, "weight")?;
64    if config.affine {
65        let bias = vb.get(size, "bias")?;
66        Ok(LayerNorm::new(weight, bias, config.eps))
67    } else {
68        Ok(LayerNorm::new_no_bias(weight, config.eps))
69    }
70}
71
72pub fn batch_norm<C: Into<BatchNormConfig>>(
73    num_features: usize,
74    config: C,
75    vb: ShardedVarBuilder,
76) -> Result<BatchNorm> {
77    let config = config.into();
78    if config.eps < 0. {
79        candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
80    }
81    let running_mean = vb.get(num_features, "running_mean")?;
82    let running_var = vb.get(num_features, "running_var")?;
83
84    if config.affine {
85        let weight = vb.get(num_features, "weight")?;
86        let bias = vb.get(num_features, "bias")?;
87        BatchNorm::new(
88            num_features,
89            running_mean,
90            running_var,
91            weight,
92            bias,
93            config.eps,
94        )
95    } else {
96        BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
97    }
98}
99
100pub fn group_norm(
101    num_groups: usize,
102    num_channels: usize,
103    eps: f64,
104    vb: ShardedVarBuilder,
105) -> Result<GroupNorm> {
106    let weight = vb.get(num_channels, "weight")?;
107    let bias = vb.get(num_channels, "bias")?;
108    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
109}
110
111pub fn conv2d(
112    in_channels: usize,
113    out_channels: usize,
114    kernel_size: usize,
115    cfg: Conv2dConfig,
116    vb: ShardedVarBuilder,
117) -> Result<Conv2d> {
118    let ws = vb.get(
119        (
120            out_channels,
121            in_channels / cfg.groups,
122            kernel_size,
123            kernel_size,
124        ),
125        "weight",
126    )?;
127    let bs = vb.get(out_channels, "bias")?;
128    Ok(Conv2d::new(ws, Some(bs), cfg))
129}
130
131pub fn conv2d_no_bias(
132    in_channels: usize,
133    out_channels: usize,
134    kernel_size: usize,
135    cfg: Conv2dConfig,
136    vb: ShardedVarBuilder,
137) -> Result<Conv2d> {
138    let ws = vb.get(
139        (
140            out_channels,
141            in_channels / cfg.groups,
142            kernel_size,
143            kernel_size,
144        ),
145        "weight",
146    )?;
147    Ok(Conv2d::new(ws, None, cfg))
148}
149
150pub fn conv1d(
151    in_channels: usize,
152    out_channels: usize,
153    kernel_size: usize,
154    cfg: Conv1dConfig,
155    vb: ShardedVarBuilder,
156) -> Result<Conv1d> {
157    let ws = vb.get(
158        (out_channels, in_channels / cfg.groups, kernel_size),
159        "weight",
160    )?;
161    let bs = vb.get(out_channels, "bias")?;
162    Ok(Conv1d::new(ws, Some(bs), cfg))
163}
164
165pub fn conv1d_no_bias(
166    in_channels: usize,
167    out_channels: usize,
168    kernel_size: usize,
169    cfg: Conv1dConfig,
170    vb: ShardedVarBuilder,
171) -> Result<Conv1d> {
172    let ws = vb.get(
173        (out_channels, in_channels / cfg.groups, kernel_size),
174        "weight",
175    )?;
176    Ok(Conv1d::new(ws, None, cfg))
177}
178
179pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
180    let ws = vb.get((out_dim, in_dim), "weight")?;
181    let bs = vb.get(out_dim, "bias")?;
182    Ok(Linear::new(ws, Some(bs)))
183}
184
185pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
186    let ws = vb.get((out_dim, in_dim), "weight")?;
187    Ok(Linear::new(ws, None))
188}
189
190pub fn linear_b(
191    in_dim: usize,
192    out_dim: usize,
193    bias: bool,
194    vb: ShardedVarBuilder,
195) -> Result<Linear> {
196    if bias {
197        linear(in_dim, out_dim, vb)
198    } else {
199        linear_no_bias(in_dim, out_dim, vb)
200    }
201}
202
203#[derive(Debug, Clone)]
204pub struct RmsNorm {
205    eps: f64,
206    weight: Tensor,
207}
208
209impl RmsNorm {
210    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
211        let w = vb.get(size, "weight")?;
212        Ok(Self { eps, weight: w })
213    }
214
215    /// Gemma uses weight + 1.0
216    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
217        let w = vb.get(size, "weight")?;
218        let w = (w + 1.0)?;
219        Ok(Self { eps, weight: w })
220    }
221
222    /// Gemma 3n uses weight
223    pub fn new_gemma_3n(
224        size: usize,
225        eps: f64,
226        with_scale: bool,
227        vb: ShardedVarBuilder,
228    ) -> Result<Self> {
229        let w = if with_scale {
230            vb.get(size, "weight")?
231        } else {
232            Tensor::ones(size, vb.dtype(), vb.device())?
233        };
234        Ok(Self { eps, weight: w })
235    }
236
237    /// Gemma uses weight + 1.0. Undo for UQFF generation.
238    pub fn undo_gemma(&self) -> Result<Self> {
239        Ok(Self {
240            eps: self.eps,
241            weight: (&self.weight - 1.0)?,
242        })
243    }
244
245    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
246        Ok(Self { eps, weight: w })
247    }
248
249    pub fn weight(&self) -> &Tensor {
250        &self.weight
251    }
252}
253
254impl Module for RmsNorm {
255    fn forward(&self, x: &Tensor) -> Result<Tensor> {
256        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
257    }
258}
259
260#[derive(Debug, Clone)]
261pub struct F32RmsNorm {
262    w: Tensor,
263    eps: f64,
264}
265
266impl F32RmsNorm {
267    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
268        Ok(Self {
269            w: vb.get((size,), "weight")?,
270            eps,
271        })
272    }
273
274    pub fn weight(&self) -> &Tensor {
275        &self.w
276    }
277}
278
279impl Module for F32RmsNorm {
280    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
281        let initial_type = xs.dtype();
282        let mut xs = xs.to_dtype(DType::F32)?;
283        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
284        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
285        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
286    }
287}
288
289#[derive(Debug, Clone)]
290pub struct QRmsNorm {
291    eps: f64,
292    weight: Tensor,
293}
294
295impl QRmsNorm {
296    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
297        let scale = scale.dequantize(&scale.device())?;
298        Ok(Self {
299            eps: eps as f64,
300            weight: scale,
301        })
302    }
303
304    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
305        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
306    }
307}
308
309/// RoPE supporting LongRope
310#[derive(Debug, Clone)]
311pub struct PhiRotaryEmbedding {
312    short_sin: Tensor,
313    short_cos: Tensor,
314    long_cos: Option<Tensor>,
315    long_sin: Option<Tensor>,
316    original_max_position_embeddings: usize,
317}
318
319#[derive(Debug, Clone, Deserialize, Serialize)]
320#[serde(rename_all = "lowercase")]
321pub enum ScaledRopeType {
322    #[serde(alias = "su")]
323    #[serde(alias = "longrope")]
324    Su,
325    #[serde(alias = "yarn")]
326    Yarn,
327    #[serde(alias = "dynamic")]
328    Dynamic,
329    #[serde(alias = "linear")]
330    Linear,
331}
332
333impl FromStr for ScaledRopeType {
334    type Err = candle_core::Error;
335    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
336        match s {
337            "su" | "longrope" => Ok(Self::Su),
338            "yarn" => Ok(Self::Yarn),
339            "linear" => Ok(Self::Linear),
340            "dynamic" => Ok(Self::Dynamic),
341            _ => Err(candle_core::Error::Msg(
342                "Expected either `su` or `yarn` scaled RoPE type.".to_string(),
343            )),
344        }
345    }
346}
347
348#[derive(Debug, Clone, Deserialize, Serialize)]
349#[serde(untagged)]
350pub enum PhiRopeScalingConfig {
351    Classic {
352        short_factor: Vec<f64>,
353        long_factor: Vec<f64>,
354        #[serde(rename = "type")]
355        scaling_type: ScaledRopeType,
356    },
357    Scaled {
358        short_factor: Vec<f64>,
359        long_factor: Vec<f64>,
360        #[serde(rename = "type")]
361        scaling_type: ScaledRopeType,
362        long_mscale: f64,
363        short_mscale: f64,
364    },
365}
366
367pub struct PhiRopeConfig {
368    pub rope_scaling: Option<PhiRopeScalingConfig>,
369    pub max_position_embeddings: usize,
370    pub original_max_position_embeddings: usize,
371    pub rope_theta: f64,
372    pub head_dim: usize,
373    pub partial_rotary_factor: Option<f64>,
374}
375
376impl PhiRotaryEmbedding {
377    fn new_classic_scaled(
378        short_factor: &[f64],
379        long_factor: &[f64],
380        scaling_type: &ScaledRopeType,
381        cfg: &PhiRopeConfig,
382        dtype: DType,
383        dev: &Device,
384    ) -> Result<Self> {
385        let max_seq_len = cfg.max_position_embeddings;
386        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
387
388        // Calculate scale
389        let scale =
390            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
391        let scaling_factor = if scale <= 1.0 {
392            1.0
393        } else {
394            match scaling_type {
395                ScaledRopeType::Su => {
396                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
397                }
398                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
399                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
400            }
401        };
402
403        // Calculate inv freqs for short, long
404        let inv_freq_long = (0..dim)
405            .step_by(2)
406            .enumerate()
407            .map(|(k, i)| {
408                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
409            })
410            .collect::<Vec<_>>();
411        let inv_freq_short = (0..dim)
412            .step_by(2)
413            .enumerate()
414            .map(|(k, i)| {
415                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
416            })
417            .collect::<Vec<_>>();
418        let inv_freq_len = inv_freq_long.len();
419
420        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
421            .to_dtype(DType::F32)?
422            .reshape((max_seq_len, 1))?;
423
424        // Calculate sin,cos for long
425        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
426        let freqs_long = t.matmul(&inv_freq_long)?;
427        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
428        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
429
430        // Calculate sin,cos for short
431        let inv_freq_short =
432            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
433        let freqs_short = t.matmul(&inv_freq_short)?;
434        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
435        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
436
437        Ok(Self {
438            short_cos,
439            short_sin,
440            long_cos: Some(long_cos),
441            long_sin: Some(long_sin),
442            original_max_position_embeddings: cfg.original_max_position_embeddings,
443        })
444    }
445
446    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
447        let max_seq_len = cfg.max_position_embeddings;
448        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
449
450        let inv_freq: Vec<_> = (0..dim)
451            .step_by(2)
452            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
453            .collect();
454        let inv_freq_len = inv_freq.len();
455        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
456        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
457            .to_dtype(DType::F32)?
458            .reshape((max_seq_len, 1))?;
459        let freqs = t.matmul(&inv_freq)?;
460        let sin = freqs.sin()?.to_dtype(dtype)?;
461        let cos = freqs.cos()?.to_dtype(dtype)?;
462        Ok(Self {
463            short_cos: cos,
464            short_sin: sin,
465            long_cos: None,
466            long_sin: None,
467            original_max_position_embeddings: cfg.original_max_position_embeddings,
468        })
469    }
470
471    #[allow(clippy::too_many_arguments)]
472    fn new_scaled(
473        short_factor: &[f64],
474        long_factor: &[f64],
475        scaling_type: &ScaledRopeType,
476        long_mscale: f64,
477        short_mscale: f64,
478        cfg: &PhiRopeConfig,
479        dtype: DType,
480        dev: &Device,
481    ) -> Result<Self> {
482        let max_seq_len = cfg.max_position_embeddings;
483        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
484
485        if !matches!(scaling_type, ScaledRopeType::Su) {
486            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
487        }
488
489        if short_factor.len() != dim / 2 {
490            candle_core::bail!(
491                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
492                short_factor.len(),
493                dim / 2
494            );
495        }
496        if long_factor.len() != dim / 2 {
497            candle_core::bail!(
498                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
499                long_factor.len(),
500                dim / 2
501            );
502        }
503
504        // Short cos/sin
505        let inv_freq_short: Vec<_> = (0..dim)
506            .step_by(2)
507            .enumerate()
508            .map(|(k, i)| {
509                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
510            })
511            .collect();
512        let inv_freq_len_short = inv_freq_short.len();
513        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
514        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
515            .to_dtype(DType::F32)?
516            .reshape((max_seq_len, 1))?;
517        let freqs_short = t_short.matmul(&inv_freq_short)?;
518        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
519        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
520
521        // Long cos/sin
522        let inv_freq_long: Vec<_> = (0..dim)
523            .step_by(2)
524            .enumerate()
525            .map(|(k, i)| {
526                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
527            })
528            .collect();
529        let inv_freq_len_long = inv_freq_long.len();
530        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
531        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
532            .to_dtype(DType::F32)?
533            .reshape((max_seq_len, 1))?;
534        let freqs_long = t_long.matmul(&inv_freq_long)?;
535        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
536        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
537        Ok(Self {
538            short_cos: cos_short,
539            short_sin: sin_short,
540            long_cos: Some(cos_long),
541            long_sin: Some(sin_long),
542            original_max_position_embeddings: cfg.original_max_position_embeddings,
543        })
544    }
545
546    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
547        let cfg: PhiRopeConfig = cfg.into();
548
549        match &cfg.rope_scaling {
550            Some(PhiRopeScalingConfig::Classic {
551                short_factor,
552                long_factor,
553                scaling_type,
554            }) => {
555                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
556            }
557
558            Some(PhiRopeScalingConfig::Scaled {
559                short_factor,
560                long_factor,
561                scaling_type,
562                long_mscale,
563                short_mscale,
564            }) => Self::new_scaled(
565                short_factor,
566                long_factor,
567                scaling_type,
568                *long_mscale,
569                *short_mscale,
570                &cfg,
571                dtype,
572                dev,
573            ),
574
575            None => Self::new_unscaled(&cfg, dtype, dev),
576        }
577    }
578
579    /// Returns (sin, cos) taking into account LongRope
580    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
581        if self.long_cos.is_none() {
582            return (&self.short_sin, &self.short_cos);
583        }
584        let seq_len = position_ids.iter().max().unwrap() + 1;
585        if seq_len > self.original_max_position_embeddings {
586            (
587                self.long_sin.as_ref().unwrap(),
588                self.long_cos.as_ref().unwrap(),
589            )
590        } else {
591            (&self.short_sin, &self.short_cos)
592        }
593    }
594
595    pub fn forward(
596        &self,
597        q: &Tensor,
598        k: &Tensor,
599        seqlen_offsets: &[usize],
600        position_ids: &[usize],
601    ) -> Result<(Tensor, Tensor)> {
602        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
603        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
604
605        let rot_dim = cos.dim(D::Minus1)? * 2;
606
607        // Case for Phi 3 / Phi 4 mini
608        if rot_dim != q.dim(D::Minus1)? {
609            let rot_dim = cos.dim(D::Minus1)? * 2;
610            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
611            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
612            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
613            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
614
615            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
616                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
617                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
618                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
619                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
620                (q_embed, k_embed)
621            } else {
622                let mut q_embeds = Vec::new();
623                let mut k_embeds = Vec::new();
624                for (i, offset) in seqlen_offsets.iter().enumerate() {
625                    let cos = cos.narrow(0, *offset, seq_len)?;
626                    let sin = sin.narrow(0, *offset, seq_len)?;
627                    let q_embed = candle_nn::rotary_emb::rope(
628                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
629                        &cos,
630                        &sin,
631                    )?;
632                    let k_embed = candle_nn::rotary_emb::rope(
633                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
634                        &cos,
635                        &sin,
636                    )?;
637                    q_embeds.push(q_embed);
638                    k_embeds.push(k_embed);
639                }
640                let q_rot = Tensor::cat(&q_embeds, 0)?;
641                let k_rot = Tensor::cat(&k_embeds, 0)?;
642                (q_rot, k_rot)
643            };
644
645            Ok((
646                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
647                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
648            ))
649        } else if seqlen_offsets.len() == 1 {
650            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
651            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
652            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
653            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
654            Ok((q_embed, k_embed))
655        } else {
656            let mut q_embeds = Vec::new();
657            let mut k_embeds = Vec::new();
658            for (i, offset) in seqlen_offsets.iter().enumerate() {
659                let cos = cos.narrow(0, *offset, seq_len)?;
660                let sin = sin.narrow(0, *offset, seq_len)?;
661                let q_embed =
662                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
663                let k_embed =
664                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
665                q_embeds.push(q_embed);
666                k_embeds.push(k_embed);
667            }
668            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
669        }
670    }
671}
672
673/// RoPE for Llama3
674#[derive(Debug, Clone)]
675pub struct Llama3RotaryEmbedding(RotaryEmbedding);
676
677#[derive(Debug, Clone, Deserialize, Serialize, Default)]
678pub enum Llama3RopeType {
679    #[serde(rename = "llama3")]
680    Llama3,
681    #[serde(rename = "linear")]
682    Linear,
683    #[default]
684    #[serde(rename = "default")]
685    Default,
686}
687
688#[derive(Debug, Clone, Deserialize, Serialize, Default)]
689pub struct Llama3RopeConfig {
690    pub factor: f32,
691    pub low_freq_factor: Option<f32>,
692    pub high_freq_factor: Option<f32>,
693    pub original_max_position_embeddings: Option<usize>,
694    pub rope_type: Llama3RopeType,
695}
696
697fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
698    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
699    (0..head_dim)
700        .step_by(2)
701        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
702        .collect()
703}
704
705fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
706    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
707    (0..head_dim)
708        .step_by(2)
709        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
710        .collect()
711}
712
713// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
714impl Llama3RotaryEmbedding {
715    pub fn new_llama3(
716        dtype: DType,
717        cfg: &llama::Config,
718        dev: &Device,
719        is_gpt_neox: bool,
720    ) -> Result<Self> {
721        match &cfg.rope_scaling {
722            None
723            | Some(Llama3RopeConfig {
724                rope_type: Llama3RopeType::Default,
725                ..
726            }) => Ok(Self(RotaryEmbedding::new(
727                cfg.rope_theta,
728                cfg.hidden_size / cfg.num_attention_heads,
729                cfg.max_position_embeddings,
730                dev,
731                is_gpt_neox,
732                dtype,
733            )?)),
734            Some(Llama3RopeConfig {
735                rope_type: Llama3RopeType::Llama3,
736                factor,
737                low_freq_factor,
738                high_freq_factor,
739                original_max_position_embeddings,
740            }) => {
741                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
742                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
743                let original_max_position_embeddings = original_max_position_embeddings
744                    .context("original_max_position_embeddings is required")?;
745
746                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
747                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
748
749                let inv_freq = calculate_default_inv_freq(cfg)
750                    .into_iter()
751                    .map(|freq| {
752                        let wavelen = 2. * PI / freq;
753                        if wavelen < high_freq_wavelen {
754                            freq
755                        } else if wavelen > low_freq_wavelen {
756                            freq / *factor
757                        } else {
758                            let smooth = (original_max_position_embeddings as f32 / wavelen
759                                - low_freq_factor)
760                                / (high_freq_factor - low_freq_factor);
761                            (1. - smooth) * freq / *factor + smooth * freq
762                        }
763                    })
764                    .collect::<Vec<_>>();
765                let inv_freq_len = inv_freq.len();
766                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
767                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
768                    .to_dtype(DType::F32)?
769                    .reshape((cfg.max_position_embeddings, 1))?;
770                let freqs = t.matmul(&inv_freq)?;
771                let sin = freqs.sin()?.to_dtype(dtype)?;
772                let cos = freqs.cos()?.to_dtype(dtype)?;
773                Ok(Self(RotaryEmbedding {
774                    sin,
775                    cos,
776                    is_gpt_neox,
777                }))
778            }
779            Some(Llama3RopeConfig {
780                rope_type: Llama3RopeType::Linear,
781                factor,
782                ..
783            }) => {
784                let inv_freq_vec = calculate_default_inv_freq(cfg)
785                    .into_iter()
786                    .map(|freq| freq / *factor)
787                    .collect::<Vec<_>>();
788                let inv_freq_len = inv_freq_vec.len();
789                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
790                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
791                    .to_dtype(DType::F32)?
792                    .reshape((cfg.max_position_embeddings, 1))?;
793                let freqs = t.matmul(&inv_freq)?;
794                let sin = freqs.sin()?.to_dtype(dtype)?;
795                let cos = freqs.cos()?.to_dtype(dtype)?;
796                Ok(Self(RotaryEmbedding {
797                    sin,
798                    cos,
799                    is_gpt_neox,
800                }))
801            }
802        }
803    }
804
805    pub fn new_llama4(
806        dtype: DType,
807        cfg: &llama4::TextConfig,
808        dev: &Device,
809        is_gpt_neox: bool,
810    ) -> Result<Self> {
811        match &cfg.rope_scaling {
812            None
813            | Some(Llama3RopeConfig {
814                rope_type: Llama3RopeType::Default,
815                ..
816            }) => Ok(Self(RotaryEmbedding::new(
817                cfg.rope_theta,
818                cfg.hidden_size / cfg.num_attention_heads,
819                cfg.max_position_embeddings,
820                dev,
821                is_gpt_neox,
822                dtype,
823            )?)),
824            Some(Llama3RopeConfig {
825                rope_type: Llama3RopeType::Llama3,
826                factor,
827                low_freq_factor,
828                high_freq_factor,
829                original_max_position_embeddings,
830            }) => {
831                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
832                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
833                let original_max_position_embeddings = original_max_position_embeddings
834                    .context("original_max_position_embeddings is required")?;
835
836                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
837                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
838
839                let inv_freq = calculate_default_inv_freq_llama4(cfg)
840                    .into_iter()
841                    .map(|freq| {
842                        let wavelen = 2. * PI / freq;
843                        if wavelen < high_freq_wavelen {
844                            freq
845                        } else if wavelen > low_freq_wavelen {
846                            freq / *factor
847                        } else {
848                            let smooth = (original_max_position_embeddings as f32 / wavelen
849                                - low_freq_factor)
850                                / (high_freq_factor - low_freq_factor);
851                            (1. - smooth) * freq / *factor + smooth * freq
852                        }
853                    })
854                    .collect::<Vec<_>>();
855                let inv_freq_len = inv_freq.len();
856                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
857                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
858                    .to_dtype(DType::F32)?
859                    .reshape((cfg.max_position_embeddings, 1))?;
860                let freqs = t.matmul(&inv_freq)?;
861                let sin = freqs.sin()?.to_dtype(dtype)?;
862                let cos = freqs.cos()?.to_dtype(dtype)?;
863                Ok(Self(RotaryEmbedding {
864                    sin,
865                    cos,
866                    is_gpt_neox,
867                }))
868            }
869            Some(Llama3RopeConfig {
870                rope_type: Llama3RopeType::Linear,
871                factor,
872                ..
873            }) => {
874                let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
875                    .into_iter()
876                    .map(|freq| freq / *factor)
877                    .collect::<Vec<_>>();
878                let inv_freq_len = inv_freq_vec.len();
879                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
880                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
881                    .to_dtype(DType::F32)?
882                    .reshape((cfg.max_position_embeddings, 1))?;
883                let freqs = t.matmul(&inv_freq)?;
884                let sin = freqs.sin()?.to_dtype(dtype)?;
885                let cos = freqs.cos()?.to_dtype(dtype)?;
886                Ok(Self(RotaryEmbedding {
887                    sin,
888                    cos,
889                    is_gpt_neox,
890                }))
891            }
892        }
893    }
894
895    pub fn new_mllama3(
896        dtype: DType,
897        cfg: &MLlamaTextConfig,
898        dev: &Device,
899        is_gpt_neox: bool,
900    ) -> Result<Self> {
901        match &cfg.rope_scaling {
902            None
903            | Some(MLlamaRopeScaling {
904                rope_type: MLlamaRopeType::Default,
905                ..
906            }) => Ok(Self(RotaryEmbedding::new(
907                cfg.rope_theta,
908                cfg.hidden_size / cfg.num_attention_heads,
909                cfg.max_position_embeddings,
910                dev,
911                is_gpt_neox,
912                dtype,
913            )?)),
914            Some(MLlamaRopeScaling {
915                rope_type: MLlamaRopeType::Llama3,
916                original_max_position_embeddings,
917                factor,
918                attention_factor: _,
919                beta_fast: _,
920                beta_slow: _,
921                short_factor: _,
922                long_factor: _,
923                low_freq_factor,
924                high_freq_factor,
925            }) => {
926                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
927                let low_freq_factor = low_freq_factor
928                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
929                let high_freq_factor = high_freq_factor
930                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
931
932                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
933                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
934
935                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
936
937                let inv_freq = (0..head_dim)
938                    .step_by(2)
939                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
940                    .map(|freq| {
941                        let wavelen = 2. * PI / freq;
942                        if wavelen < high_freq_wavelen {
943                            freq
944                        } else if wavelen > low_freq_wavelen {
945                            freq / factor
946                        } else {
947                            let smooth = (*original_max_position_embeddings as f32 / wavelen
948                                - low_freq_factor)
949                                / (high_freq_factor - low_freq_factor);
950                            (1. - smooth) * freq / factor + smooth * freq
951                        }
952                    })
953                    .collect::<Vec<_>>();
954                let inv_freq_len = inv_freq.len();
955                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
956
957                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
958                    .to_dtype(DType::F32)?
959                    .reshape((cfg.max_position_embeddings, 1))?;
960                let freqs = t.matmul(&inv_freq)?;
961                let sin = freqs.sin()?.to_dtype(dtype)?;
962                let cos = freqs.cos()?.to_dtype(dtype)?;
963                Ok(Self(RotaryEmbedding {
964                    sin,
965                    cos,
966                    is_gpt_neox,
967                }))
968            }
969            Some(MLlamaRopeScaling {
970                rope_type: other, ..
971            }) => {
972                candle_core::bail!(
973                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
974                )
975            }
976        }
977    }
978
979    pub fn forward(
980        &self,
981        q: &Tensor,
982        k: &Tensor,
983        seqlen_offsets: &[usize],
984    ) -> Result<(Tensor, Tensor)> {
985        self.0.forward(q, k, seqlen_offsets)
986    }
987}
988
989/// RoPE for SmolLm3
990#[derive(Debug, Clone)]
991pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);
992
993#[derive(Debug, Clone, Deserialize, Serialize, Default)]
994pub enum SmolLm3RopeType {
995    #[serde(rename = "llama3")]
996    Llama3,
997    #[serde(rename = "linear")]
998    Linear,
999    #[default]
1000    #[serde(rename = "default")]
1001    Default,
1002}
1003
1004#[derive(Debug, Clone, Deserialize, Serialize, Default)]
1005pub struct SmolLm3RopeConfig {
1006    pub factor: f32,
1007    pub low_freq_factor: Option<f32>,
1008    pub high_freq_factor: Option<f32>,
1009    pub original_max_position_embeddings: Option<usize>,
1010    pub rope_type: SmolLm3RopeType,
1011}
1012
1013fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
1014    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1015    (0..head_dim)
1016        .step_by(2)
1017        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1018        .collect()
1019}
1020
1021impl SmolLm3RotaryEmbedding {
1022    pub fn new_llama3(
1023        dtype: DType,
1024        cfg: &smollm3::Config,
1025        dev: &Device,
1026        is_gpt_neox: bool,
1027    ) -> Result<Self> {
1028        match &cfg.rope_scaling {
1029            None
1030            | Some(SmolLm3RopeConfig {
1031                rope_type: SmolLm3RopeType::Default,
1032                ..
1033            }) => Ok(Self(RotaryEmbedding::new(
1034                cfg.rope_theta,
1035                cfg.hidden_size / cfg.num_attention_heads,
1036                cfg.max_position_embeddings,
1037                dev,
1038                is_gpt_neox,
1039                dtype,
1040            )?)),
1041            Some(SmolLm3RopeConfig {
1042                rope_type: SmolLm3RopeType::Llama3,
1043                factor,
1044                low_freq_factor,
1045                high_freq_factor,
1046                original_max_position_embeddings,
1047            }) => {
1048                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
1049                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
1050                let original_max_position_embeddings = original_max_position_embeddings
1051                    .context("original_max_position_embeddings is required")?;
1052
1053                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
1054                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
1055
1056                let inv_freq = calculate_default_inv_freq_smollm3(cfg)
1057                    .into_iter()
1058                    .map(|freq| {
1059                        let wavelen = 2. * PI / freq;
1060                        if wavelen < high_freq_wavelen {
1061                            freq
1062                        } else if wavelen > low_freq_wavelen {
1063                            freq / *factor
1064                        } else {
1065                            let smooth = (original_max_position_embeddings as f32 / wavelen
1066                                - low_freq_factor)
1067                                / (high_freq_factor - low_freq_factor);
1068                            (1. - smooth) * freq / *factor + smooth * freq
1069                        }
1070                    })
1071                    .collect::<Vec<_>>();
1072                let inv_freq_len = inv_freq.len();
1073                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1074                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1075                    .to_dtype(DType::F32)?
1076                    .reshape((cfg.max_position_embeddings, 1))?;
1077                let freqs = t.matmul(&inv_freq)?;
1078                let sin = freqs.sin()?.to_dtype(dtype)?;
1079                let cos = freqs.cos()?.to_dtype(dtype)?;
1080                Ok(Self(RotaryEmbedding {
1081                    sin,
1082                    cos,
1083                    is_gpt_neox,
1084                }))
1085            }
1086            Some(SmolLm3RopeConfig {
1087                rope_type: SmolLm3RopeType::Linear,
1088                factor,
1089                ..
1090            }) => {
1091                let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
1092                    .into_iter()
1093                    .map(|freq| freq / *factor)
1094                    .collect::<Vec<_>>();
1095                let inv_freq_len = inv_freq_vec.len();
1096                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1097                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1098                    .to_dtype(DType::F32)?
1099                    .reshape((cfg.max_position_embeddings, 1))?;
1100                let freqs = t.matmul(&inv_freq)?;
1101                let sin = freqs.sin()?.to_dtype(dtype)?;
1102                let cos = freqs.cos()?.to_dtype(dtype)?;
1103                Ok(Self(RotaryEmbedding {
1104                    sin,
1105                    cos,
1106                    is_gpt_neox,
1107                }))
1108            }
1109        }
1110    }
1111
1112    pub fn forward(
1113        &self,
1114        q: &Tensor,
1115        k: &Tensor,
1116        seqlen_offsets: &[usize],
1117    ) -> Result<(Tensor, Tensor)> {
1118        self.0.forward(q, k, seqlen_offsets)
1119    }
1120}
1121
1122// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
1123#[derive(Debug, Clone)]
1124pub struct Qwen2VLRotaryEmbedding {
1125    inv_freq: Tensor,
1126    mrope_section: Vec<usize>,
1127}
1128
1129impl Qwen2VLRotaryEmbedding {
1130    pub fn new(
1131        base: f32,
1132        head_dim: usize,
1133        device: &Device,
1134        mrope_section: Vec<usize>,
1135    ) -> Result<Self> {
1136        let inv_freq: Vec<_> = (0..head_dim)
1137            .step_by(2)
1138            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1139            .collect();
1140        let inv_freq_len = inv_freq.len();
1141        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1142        Ok(Self {
1143            inv_freq,
1144            mrope_section,
1145        })
1146    }
1147
1148    /// (cos, sin)
1149    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1150        let inv_freq_expanded =
1151            self.inv_freq
1152                .reshape((1, 1, (), 1))?
1153                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1154        let position_ids_expanded = position_ids.unsqueeze(2)?;
1155        let freqs = inv_freq_expanded
1156            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1157            .transpose(2, 3)?;
1158        let cos = freqs.cos()?;
1159        let sin = freqs.sin()?;
1160
1161        let cos = Tensor::cat(
1162            &cos.split(&self.mrope_section, D::Minus1)?
1163                .into_iter()
1164                .enumerate()
1165                .map(|(i, m)| m.i(i % 3))
1166                .collect::<Result<Vec<_>>>()?,
1167            D::Minus1,
1168        )?
1169        .squeeze(0)?
1170        .to_dtype(dtype)?
1171        .contiguous()?;
1172        let sin = Tensor::cat(
1173            &sin.split(&self.mrope_section, D::Minus1)?
1174                .into_iter()
1175                .enumerate()
1176                .map(|(i, m)| m.i(i % 3))
1177                .collect::<Result<Vec<_>>>()?,
1178            D::Minus1,
1179        )?
1180        .squeeze(0)?
1181        .to_dtype(dtype)?
1182        .contiguous()?;
1183
1184        Ok((cos, sin))
1185    }
1186
1187    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
1188    pub fn forward(
1189        &self,
1190        (cos, sin): &(Tensor, Tensor),
1191        q: &mut Tensor,
1192        k: &mut Tensor,
1193    ) -> Result<()> {
1194        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1195        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1196        Ok(())
1197    }
1198}
1199
1200#[derive(Debug, Clone)]
1201pub struct Qwen2_5VLRotaryEmbedding {
1202    inv_freq: Tensor,
1203    mrope_section: Vec<usize>,
1204}
1205
1206impl Qwen2_5VLRotaryEmbedding {
1207    pub fn new(
1208        base: f32,
1209        head_dim: usize,
1210        device: &Device,
1211        mrope_section: Vec<usize>,
1212    ) -> Result<Self> {
1213        let inv_freq: Vec<_> = (0..head_dim)
1214            .step_by(2)
1215            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1216            .collect();
1217        let inv_freq_len = inv_freq.len();
1218        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1219        Ok(Self {
1220            inv_freq,
1221            mrope_section,
1222        })
1223    }
1224
1225    /// (cos, sin)
1226    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1227        let inv_freq_expanded =
1228            self.inv_freq
1229                .reshape((1, 1, (), 1))?
1230                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1231        let position_ids_expanded = position_ids.unsqueeze(2)?;
1232        let freqs = inv_freq_expanded
1233            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1234            .transpose(2, 3)?;
1235        let cos = freqs.cos()?;
1236        let sin = freqs.sin()?;
1237
1238        let cos = Tensor::cat(
1239            &cos.split(&self.mrope_section, D::Minus1)?
1240                .into_iter()
1241                .enumerate()
1242                .map(|(i, m)| m.i(i % 3))
1243                .collect::<Result<Vec<_>>>()?,
1244            D::Minus1,
1245        )?
1246        .squeeze(0)?
1247        .to_dtype(dtype)?
1248        .contiguous()?;
1249        let sin = Tensor::cat(
1250            &sin.split(&self.mrope_section, D::Minus1)?
1251                .into_iter()
1252                .enumerate()
1253                .map(|(i, m)| m.i(i % 3))
1254                .collect::<Result<Vec<_>>>()?,
1255            D::Minus1,
1256        )?
1257        .squeeze(0)?
1258        .to_dtype(dtype)?
1259        .contiguous()?;
1260
1261        Ok((cos, sin))
1262    }
1263
1264    pub fn forward(
1265        &self,
1266        (cos, sin): &(Tensor, Tensor),
1267        q: &mut Tensor,
1268        k: &mut Tensor,
1269    ) -> Result<()> {
1270        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1271        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1272        Ok(())
1273    }
1274}
1275
1276#[derive(Debug, Clone)]
1277pub struct DeepSeekV2RotaryEmbedding {
1278    sin: Tensor,
1279    cos: Tensor,
1280}
1281
1282#[derive(Debug, Clone, Deserialize, Serialize)]
1283#[serde(untagged)]
1284pub enum DeepSeekV2RopeScaling {
1285    Yarn {
1286        original_max_position_embeddings: usize,
1287        beta_fast: f32,
1288        beta_slow: f32,
1289        mscale: f32,
1290        mscale_all_dim: f32,
1291        factor: f32,
1292        #[serde(rename = "type")]
1293        scaling_type: ScaledRopeType,
1294    },
1295    LinearOrDynamic {
1296        #[serde(rename = "type")]
1297        scaling_type: ScaledRopeType,
1298        factor: f64,
1299    },
1300}
1301
1302pub struct DeepSeekV2RopeConfig {
1303    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
1304    pub max_position_embeddings: usize,
1305    pub rope_theta: f32,
1306    pub qk_rope_head_dim: usize,
1307}
1308
1309impl DeepSeekV2RotaryEmbedding {
1310    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1311        let max_seq_len = cfg.max_position_embeddings;
1312        let dim = cfg.qk_rope_head_dim;
1313
1314        let inv_freq: Vec<_> = (0..dim)
1315            .step_by(2)
1316            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1317            .collect();
1318        let inv_freq_len = inv_freq.len();
1319        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1320        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1321            .to_dtype(DType::F32)?
1322            .reshape((max_seq_len, 1))?;
1323        let freqs = t.matmul(&inv_freq)?;
1324
1325        let sin = freqs.sin()?.to_dtype(dtype)?;
1326        let cos = freqs.cos()?.to_dtype(dtype)?;
1327
1328        Ok(Self { sin, cos })
1329    }
1330
1331    fn yarn_find_correction_dim(
1332        num_rot: f32,
1333        dim: usize,
1334        base: f32,
1335        max_position_embeddings: usize,
1336    ) -> f32 {
1337        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1338            / (2. * base.ln())
1339    }
1340
1341    fn yarn_find_correction_range(
1342        low_rot: f32,
1343        high_rot: f32,
1344        dim: usize,
1345        base: f32,
1346        max_position_embeddings: usize,
1347    ) -> (f32, f32) {
1348        let low =
1349            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1350        let high =
1351            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1352        (low.max(0.), high.min(dim as f32 - 1.))
1353    }
1354
1355    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1356        if min == max {
1357            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
1358            max += 0.001;
1359        }
1360        let linear_func =
1361            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1362        linear_func.clamp(0., 1)
1363    }
1364
1365    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1366        if scale <= 1. {
1367            return 1.;
1368        }
1369        0.1 * mscale * scale.ln() + 1.
1370    }
1371
1372    #[allow(clippy::too_many_arguments)]
1373    fn new_yarn(
1374        cfg: &DeepSeekV2RopeConfig,
1375        dtype: DType,
1376        dev: &Device,
1377        original_max_position_embeddings: usize,
1378        beta_fast: f32,
1379        beta_slow: f32,
1380        factor: f32,
1381        mscale: f32,
1382        mscale_all_dim: f32,
1383    ) -> Result<Self> {
1384        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1385            .step_by(2)
1386            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1387            .collect();
1388        let freq_extra_len = freq_extra.len();
1389        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1390        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1391            .step_by(2)
1392            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1393            .collect();
1394        let freq_inter_len = freq_inter.len();
1395        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1396
1397        let (low, high) = Self::yarn_find_correction_range(
1398            beta_fast,
1399            beta_slow,
1400            cfg.qk_rope_head_dim,
1401            cfg.rope_theta,
1402            original_max_position_embeddings,
1403        );
1404        let inv_freq_mask =
1405            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1406        let inv_freq = freq_inter
1407            .broadcast_mul(&(1. - &inv_freq_mask)?)?
1408            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1409
1410        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1411            .to_dtype(DType::F32)?
1412            .reshape((cfg.max_position_embeddings, 1))?;
1413        let freqs = t.matmul(&inv_freq)?;
1414
1415        let mscale =
1416            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1417        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1418        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1419
1420        Ok(Self { sin, cos })
1421    }
1422
1423    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1424        match &cfg.rope_scaling {
1425            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1426                scaling_type: _,
1427                factor: _,
1428            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1429            Some(DeepSeekV2RopeScaling::Yarn {
1430                original_max_position_embeddings,
1431                beta_fast,
1432                beta_slow,
1433                factor,
1434                mscale,
1435                mscale_all_dim,
1436                scaling_type: _,
1437            }) => Self::new_yarn(
1438                cfg,
1439                dtype,
1440                dev,
1441                *original_max_position_embeddings,
1442                *beta_fast,
1443                *beta_slow,
1444                *factor,
1445                *mscale,
1446                *mscale_all_dim,
1447            ),
1448            None => Self::new_unscaled(cfg, dtype, dev),
1449        }
1450    }
1451
1452    pub fn forward(
1453        &self,
1454        q: &Tensor,
1455        k: &Tensor,
1456        seqlen_offsets: &[usize],
1457    ) -> Result<(Tensor, Tensor)> {
1458        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1459
1460        if seqlen_offsets.len() == 1 {
1461            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1462            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1463            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1464            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1465            Ok((q_embed, k_embed))
1466        } else {
1467            let mut q_embeds = Vec::new();
1468            let mut k_embeds = Vec::new();
1469            for (i, offset) in seqlen_offsets.iter().enumerate() {
1470                let cos = self.cos.narrow(0, *offset, seq_len)?;
1471                let sin = self.sin.narrow(0, *offset, seq_len)?;
1472                let q_embed = candle_nn::rotary_emb::rope_i(
1473                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
1474                    &cos,
1475                    &sin,
1476                )?;
1477                let k_embed = candle_nn::rotary_emb::rope_i(
1478                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
1479                    &cos,
1480                    &sin,
1481                )?;
1482                q_embeds.push(q_embed);
1483                k_embeds.push(k_embed);
1484            }
1485            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1486        }
1487    }
1488}
1489
1490#[derive(Debug, Clone)]
1491pub struct Phi4MMRotaryEmbedding {
1492    short_sin: Tensor,
1493    short_cos: Tensor,
1494    long_cos: Option<Tensor>,
1495    long_sin: Option<Tensor>,
1496    original_max_position_embeddings: usize,
1497}
1498
1499#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1500#[serde(rename_all = "lowercase")]
1501pub enum Phi4MMScaledRopeType {
1502    #[serde(alias = "longrope")]
1503    LongRope,
1504    #[default]
1505    Default,
1506}
1507
1508#[derive(Debug, Clone, Deserialize, Serialize)]
1509pub struct Phi4MMRopeScalingConfig {
1510    short_factor: Option<Vec<f64>>,
1511    long_factor: Option<Vec<f64>>,
1512    #[serde(rename = "type")]
1513    scaling_type: Phi4MMScaledRopeType,
1514}
1515
1516impl Phi4MMRotaryEmbedding {
1517    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1518        let max_seq_len = cfg.max_position_embeddings;
1519        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1520
1521        let inv_freq: Vec<_> = (0..dim)
1522            .step_by(2)
1523            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1524            .collect();
1525        let inv_freq_len = inv_freq.len();
1526        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1527        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1528            .to_dtype(DType::F32)?
1529            .reshape((max_seq_len, 1))?;
1530        let freqs = t.matmul(&inv_freq)?;
1531        let sin = freqs.sin()?.to_dtype(dtype)?;
1532        let cos = freqs.cos()?.to_dtype(dtype)?;
1533        Ok(Self {
1534            short_cos: cos,
1535            short_sin: sin,
1536            long_cos: None,
1537            long_sin: None,
1538            original_max_position_embeddings: cfg.original_max_position_embeddings,
1539        })
1540    }
1541
1542    #[allow(clippy::too_many_arguments)]
1543    fn new_longrope(
1544        short_factor: &[f64],
1545        long_factor: &[f64],
1546        cfg: &Phi4MMConfig,
1547        dtype: DType,
1548        dev: &Device,
1549    ) -> Result<Self> {
1550        let max_seq_len = cfg.max_position_embeddings;
1551        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1552
1553        // LongRope attention scaling: sqrt(1 + ln(scale) / ln(original_max_position_embeddings)), applied to the sin/cos tables below.
1554        let scale =
1555            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1556        let scaling_factor = if scale <= 1.0 {
1557            1.0
1558        } else {
1559            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1560        };
1561
1562        // Short cos/sin
1563        let inv_freq_short: Vec<_> = (0..dim)
1564            .step_by(2)
1565            .enumerate()
1566            .map(|(k, i)| {
1567                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1568            })
1569            .collect();
1570        let inv_freq_len_short = inv_freq_short.len();
1571        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1572        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1573            .to_dtype(DType::F32)?
1574            .reshape((max_seq_len, 1))?;
1575        let freqs_short = t_short.matmul(&inv_freq_short)?;
1576        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1577        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1578
1579        // Long cos/sin
1580        let inv_freq_long: Vec<_> = (0..dim)
1581            .step_by(2)
1582            .enumerate()
1583            .map(|(k, i)| {
1584                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1585            })
1586            .collect();
1587        let inv_freq_len_long = inv_freq_long.len();
1588        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1589        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1590            .to_dtype(DType::F32)?
1591            .reshape((max_seq_len, 1))?;
1592        let freqs_long = t_long.matmul(&inv_freq_long)?;
1593        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1594        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1595
1596        Ok(Self {
1597            short_cos: cos_short,
1598            short_sin: sin_short,
1599            long_cos: Some(cos_long),
1600            long_sin: Some(sin_long),
1601            original_max_position_embeddings: cfg.original_max_position_embeddings,
1602        })
1603    }
1604
1605    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1606        match &cfg.rope_scaling {
1607            Some(Phi4MMRopeScalingConfig {
1608                scaling_type: Phi4MMScaledRopeType::LongRope,
1609                short_factor: Some(short_factor),
1610                long_factor: Some(long_factor),
1611            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1612
1613            _ => Self::new_unscaled(cfg, dtype, dev),
1614        }
1615    }
1616
1617    /// Returns (sin, cos), selecting the long-context (LongRope) tables when the largest position id exceeds `original_max_position_embeddings`.
1618    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1619        if self.long_cos.is_none() {
1620            return (&self.short_sin, &self.short_cos);
1621        }
1622        let seq_len = position_ids.iter().max().unwrap() + 1;
1623        if seq_len > self.original_max_position_embeddings {
1624            (
1625                self.long_sin.as_ref().unwrap(),
1626                self.long_cos.as_ref().unwrap(),
1627            )
1628        } else {
1629            (&self.short_sin, &self.short_cos)
1630        }
1631    }
1632
1633    pub fn forward(
1634        &self,
1635        q: &Tensor,
1636        k: &Tensor,
1637        seqlen_offsets: &[usize],
1638        position_ids: &[usize],
1639    ) -> Result<(Tensor, Tensor)> {
1640        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1641        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1642
1643        let rot_dim = cos.dim(D::Minus1)? * 2;
1644        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1645        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1646        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1647        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1648
1649        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1650            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1651            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1652            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1653            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1654            (q_embed, k_embed)
1655        } else {
1656            let mut q_embeds = Vec::new();
1657            let mut k_embeds = Vec::new();
1658            for (i, offset) in seqlen_offsets.iter().enumerate() {
1659                let cos = cos.narrow(0, *offset, seq_len)?;
1660                let sin = sin.narrow(0, *offset, seq_len)?;
1661                let q_embed = candle_nn::rotary_emb::rope(
1662                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1663                    &cos,
1664                    &sin,
1665                )?;
1666                let k_embed = candle_nn::rotary_emb::rope(
1667                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1668                    &cos,
1669                    &sin,
1670                )?;
1671                q_embeds.push(q_embed);
1672                k_embeds.push(k_embed);
1673            }
1674            let q_rot = Tensor::cat(&q_embeds, 0)?;
1675            let k_rot = Tensor::cat(&k_embeds, 0)?;
1676            (q_rot, k_rot)
1677        };
1678
1679        Ok((
1680            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1681            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1682        ))
1683    }
1684}
1685
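/// Rotary embedding for Gemma 3n. Supports optional linear RoPE scaling, in which the
/// inverse frequencies are divided by the configured `factor`.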
1686#[derive(Debug, Clone)]
1687pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);
1688
1689#[derive(Debug, Clone, Deserialize, Serialize)]
1690#[serde(rename_all = "lowercase")]
1691pub enum Gemma3nScaledRopeType {
1692    #[serde(alias = "linear")]
1693    Linear,
1694}
1695
1696#[derive(Debug, Clone, Deserialize, Serialize)]
1697pub struct Gemma3nRopeScalingConfig {
1698    factor: f64,
1699    rope_type: Gemma3nScaledRopeType,
1700}
1701
1702impl Gemma3nRotaryEmbedding {
1703    fn new_linear(
1704        cfg: &Gemma3nTextConfig,
1705        factor: f64,
1706        is_gpt_neox: bool,
1707        dtype: DType,
1708        dev: &Device,
1709    ) -> Result<Self> {
1710        let max_seq_len = cfg.max_position_embeddings;
1711        let dim = cfg.head_dim;
1712
1713        let inv_freq: Vec<_> = (0..dim)
1714            .step_by(2)
1715            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1716            .collect();
1717        let inv_freq_len = inv_freq.len();
1718        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1719        let inv_freq = (inv_freq / factor)?;
1720
1721        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1722            .to_dtype(DType::F32)?
1723            .reshape((max_seq_len, 1))?;
1724        let freqs = t.matmul(&inv_freq)?;
1725        let sin = freqs.sin()?.to_dtype(dtype)?;
1726        let cos = freqs.cos()?.to_dtype(dtype)?;
1727        Ok(Self(RotaryEmbedding {
1728            cos,
1729            sin,
1730            is_gpt_neox,
1731        }))
1732    }
1733
1734    pub fn new(
1735        is_gpt_neox: bool,
1736        dtype: DType,
1737        cfg: &Gemma3nTextConfig,
1738        dev: &Device,
1739    ) -> Result<Self> {
1740        match &cfg.rope_scaling {
1741            Some(Gemma3RopeScalingConfig {
1742                rope_type: Gemma3ScaledRopeType::Linear,
1743                factor,
1744            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1745
1746            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1747        }
1748    }
1749
1750    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
1751        self.0.get_cos_sin()
1752    }
1753
1754    pub fn forward(
1755        &self,
1756        q: &Tensor,
1757        k: &Tensor,
1758        seqlen_offsets: &[usize],
1759    ) -> Result<(Tensor, Tensor)> {
1760        self.0.forward(q, k, seqlen_offsets)
1761    }
1762}
1763
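/// Rotary embedding for Gemma 3. Supports optional linear RoPE scaling, in which the
/// inverse frequencies are divided by the configured `factor`.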
1764#[derive(Debug, Clone)]
1765pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1766
1767#[derive(Debug, Clone, Deserialize, Serialize)]
1768#[serde(rename_all = "lowercase")]
1769pub enum Gemma3ScaledRopeType {
1770    #[serde(alias = "linear")]
1771    Linear,
1772}
1773
1774#[derive(Debug, Clone, Deserialize, Serialize)]
1775pub struct Gemma3RopeScalingConfig {
1776    factor: f64,
1777    rope_type: Gemma3ScaledRopeType,
1778}
1779
1780impl Gemma3RotaryEmbedding {
1781    fn new_linear(
1782        cfg: &Gemma3TextConfig,
1783        factor: f64,
1784        is_gpt_neox: bool,
1785        dtype: DType,
1786        dev: &Device,
1787    ) -> Result<Self> {
1788        let max_seq_len = cfg.max_position_embeddings;
1789        let dim = cfg.head_dim;
1790
1791        let inv_freq: Vec<_> = (0..dim)
1792            .step_by(2)
1793            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1794            .collect();
1795        let inv_freq_len = inv_freq.len();
1796        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1797        let inv_freq = (inv_freq / factor)?;
1798
1799        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1800            .to_dtype(DType::F32)?
1801            .reshape((max_seq_len, 1))?;
1802        let freqs = t.matmul(&inv_freq)?;
1803        let sin = freqs.sin()?.to_dtype(dtype)?;
1804        let cos = freqs.cos()?.to_dtype(dtype)?;
1805        Ok(Self(RotaryEmbedding {
1806            cos,
1807            sin,
1808            is_gpt_neox,
1809        }))
1810    }
1811
1812    pub fn new(
1813        is_gpt_neox: bool,
1814        dtype: DType,
1815        cfg: &Gemma3TextConfig,
1816        dev: &Device,
1817    ) -> Result<Self> {
1818        match &cfg.rope_scaling {
1819            Some(Gemma3RopeScalingConfig {
1820                rope_type: Gemma3ScaledRopeType::Linear,
1821                factor,
1822            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1823
1824            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1825        }
1826    }
1827
1828    pub fn forward(
1829        &self,
1830        q: &Tensor,
1831        k: &Tensor,
1832        seqlen_offsets: &[usize],
1833    ) -> Result<(Tensor, Tensor)> {
1834        self.0.forward(q, k, seqlen_offsets)
1835    }
1836}
1837
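/// Rotary embedding used by the Dia model: geometric timescales between `min_timescale`
/// and `max_timescale` are precomputed per frequency, and positions are rotated by
/// mixing the two halves of the head dimension with the resulting sin/cos values.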
1838pub struct DiaRotaryEmbedding {
1839    timescale: Tensor,
1840    dtype: DType,
1841}
1842
1843impl DiaRotaryEmbedding {
1844    pub fn new(
1845        min_timescale: f32,
1846        max_timescale: f32,
1847        head_dim: usize,
1848        device: &Device,
1849        dtype: DType,
1850    ) -> Result<Self> {
1851        assert_eq!(head_dim % 2, 0);
1852        let half_embedding_dim = head_dim / 2;
1853
1854        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
1855        let timescale = fraction
1856            .into_iter()
1857            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
1858            .collect::<Vec<_>>();
1859
1860        let timescale_len = timescale.len();
1861        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
1862
1863        Ok(Self { timescale, dtype })
1864    }
1865
1866    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
1867        let freqs = positions
1868            .unsqueeze(D::Minus1)?
1869            .unsqueeze(D::Minus1)?
1870            .broadcast_div(&self.timescale)?;
1871
1872        let sin = freqs.sin()?.to_dtype(self.dtype)?;
1873        let cos = freqs.cos()?.to_dtype(self.dtype)?;
1874
1875        let split = xs.chunk(2, D::Minus1)?;
1876        let first_half = &split[0];
1877        let second_half = &split[1];
1878
1879        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
1880        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
1881
1882        Tensor::cat(&[first_part, second_part], D::Minus1)
1883    }
1884}
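
/// Linear layer backed by a `QMatMul` (either a quantized `QTensor` or a plain weight
/// tensor) with an optional bias. When the inner matmul is quantized, inputs are upcast
/// to F32 for the matmul and the output is cast back to `self.dtype`.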
1885#[derive(Debug, Clone)]
1886pub struct QLinear {
1887    inner: QMatMul,
1888    bias: Option<Tensor>,
1889    dtype: DType,
1890}
1891
1892impl QLinear {
1893    pub fn new<R: std::io::Read + std::io::Seek>(
1894        ct: &mut Content<'_, R>,
1895        name: &str,
1896        device: &Device,
1897    ) -> Result<Self> {
1898        let w = ct.tensor(&format!("{name}.weight"), device)?;
1899        let b = ct.tensor(&format!("{name}.bias"), device)?;
1900        let inner = QMatMul::from_qtensor(w)?;
1901        let bias = b.dequantize(device)?;
1902        Ok(Self {
1903            inner,
1904            bias: Some(bias),
1905            dtype: DType::F32,
1906        })
1907    }
1908
1909    pub fn from_linear(linear: Linear) -> Self {
1910        Self {
1911            inner: QMatMul::Tensor(linear.weight().clone()),
1912            bias: linear.bias().cloned(),
1913            dtype: linear.weight().dtype(),
1914        }
1915    }
1916
1917    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1918        let dtype = w.dtype();
1919        Self {
1920            inner: QMatMul::Tensor(w),
1921            bias: b,
1922            dtype,
1923        }
1924    }
1925
1926    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1927        if let Some(ref b) = b {
1928            assert_eq!(b.dtype(), DType::F32);
1929        }
1930        Self {
1931            inner: QMatMul::QTensor(Arc::new(w)),
1932            bias: b,
1933            dtype: DType::F32,
1934        }
1935    }
1936
1937    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1938        Self {
1939            inner,
1940            bias: old.bias.clone(),
1941            dtype: old.dtype,
1942        }
1943    }
1944
1945    pub fn inner(&mut self) -> &mut QMatMul {
1946        &mut self.inner
1947    }
1948
1949    pub fn inner_ref(&self) -> &QMatMul {
1950        &self.inner
1951    }
1952
1953    pub fn is_quant(&self) -> bool {
1954        matches!(self.inner, QMatMul::QTensor(_))
1955    }
1956
1957    pub fn bias(&self) -> Option<&Tensor> {
1958        self.bias.as_ref()
1959    }
1960
1961    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1962        self.bias.as_mut()
1963    }
1964}
1965
1966impl Module for QLinear {
1967    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1968        let xs = if self.is_quant() {
1969            xs.to_dtype(DType::F32)?
1970        } else {
1971            xs.clone()
1972        };
1973        if let Some(bias) = &self.bias {
1974            self.inner
1975                .forward(&xs)?
1976                .broadcast_add(bias)?
1977                .to_dtype(self.dtype)
1978        } else {
1979            self.inner.forward(&xs)?.to_dtype(self.dtype)
1980        }
1981    }
1982}
1983
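/// Standard rotary embedding with precomputed sin/cos tables. `is_gpt_neox` selects the
/// half-rotated (GPT-NeoX) layout instead of the interleaved one. A minimal usage sketch
/// (assumes a CPU device and F32 tables; not a doc-test):
///
/// ```ignore
/// let rope = RotaryEmbedding::new(10_000.0, 64, 4096, &Device::Cpu, true, DType::F32)?;
/// // q and k have shape (batch, num_heads, seq_len, head_dim).
/// let (q, k) = rope.forward(&q, &k, &[0])?;
/// ```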
1984#[derive(Debug, Clone)]
1985pub struct RotaryEmbedding {
1986    cos: Tensor,
1987    sin: Tensor,
1988    is_gpt_neox: bool,
1989}
1990
1991impl RotaryEmbedding {
1992    pub fn new(
1993        base: f32,
1994        head_dim: usize,
1995        max_position_embeddings: usize,
1996        device: &Device,
1997        is_gpt_neox: bool,
1998        dtype: DType,
1999    ) -> Result<Self> {
2000        let inv_freq: Vec<_> = (0..head_dim)
2001            .step_by(2)
2002            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
2003            .collect();
2004        let inv_freq_len = inv_freq.len();
2005        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2006        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2007            .to_dtype(DType::F32)?
2008            .reshape((max_position_embeddings, 1))?;
2009        let freqs = t.matmul(&inv_freq)?;
2010        let sin = freqs.sin()?.to_dtype(dtype)?;
2011        let cos = freqs.cos()?.to_dtype(dtype)?;
2012
2013        Ok(Self {
2014            cos,
2015            sin,
2016            is_gpt_neox,
2017        })
2018    }
2019
2020    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
2021        Ok((self.cos.clone(), self.sin.clone()))
2022    }
2023
2024    pub fn new_partial(
2025        base: f32,
2026        rot_dim: usize,
2027        max_position_embeddings: usize,
2028        device: &Device,
2029        is_gpt_neox: bool,
2030        dtype: DType,
2031    ) -> Result<Self> {
2032        let inv_freq: Vec<_> = (0..rot_dim)
2033            .step_by(2)
2034            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
2035            .collect();
2036        let inv_freq_len = inv_freq.len();
2037        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2038        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2039            .to_dtype(DType::F32)?
2040            .reshape((max_position_embeddings, 1))?;
2041        let freqs = t.matmul(&inv_freq)?;
2042        let sin = freqs.sin()?.to_dtype(dtype)?;
2043        let cos = freqs.cos()?.to_dtype(dtype)?;
2044
2045        Ok(Self {
2046            cos,
2047            sin,
2048            is_gpt_neox,
2049        })
2050    }
2051
2052    pub fn forward(
2053        &self,
2054        q: &Tensor,
2055        k: &Tensor,
2056        seqlen_offsets: &[usize],
2057    ) -> Result<(Tensor, Tensor)> {
2058        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
2059        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
2060
2061        let rope = if self.is_gpt_neox {
2062            candle_nn::rotary_emb::rope
2063        } else {
2064            candle_nn::rotary_emb::rope_i
2065        };
2066
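        // On CUDA, when q and k have the same number of heads, use the fused in-place
        // rotary kernel from `mistralrs_quant` instead of the candle_nn implementations.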
2067        if cfg!(feature = "cuda") && qh == kh {
2068            let (cos, sin) = if seqlen_offsets.len() == 1 {
2069                (
2070                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
2071                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
2072                )
2073            } else {
2074                let mut cos_s = Vec::new();
2075                let mut sin_s = Vec::new();
2076                for offset in seqlen_offsets {
2077                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
2078                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
2079                }
2080                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
2081            };
2082
2083            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
2084            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
2085            mistralrs_quant::rotary::apply_rotary_inplace(
2086                &q_embed,
2087                &k_embed,
2088                &cos,
2089                &sin,
2090                self.is_gpt_neox,
2091            )?;
2092            let mut q = q_embed
2093                .reshape((b_sz, seq_len, qh, n_embd))?
2094                .transpose(1, 2)?;
2095            let mut k = k_embed
2096                .reshape((b_sz, seq_len, kh, n_embd))?
2097                .transpose(1, 2)?;
2098            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
2099                q = q.contiguous()?;
2100                k = k.contiguous()?;
2101            }
2102            Ok((q, k))
2103        } else if seqlen_offsets.len() == 1 {
2104            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2105            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2106            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
2107            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
2108            Ok((q_embed, k_embed))
2109        } else {
2110            let mut q_embeds = Vec::new();
2111            let mut k_embeds = Vec::new();
2112            for (i, offset) in seqlen_offsets.iter().enumerate() {
2113                let cos = self.cos.narrow(0, *offset, seq_len)?;
2114                let sin = self.sin.narrow(0, *offset, seq_len)?;
2115                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2116                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2117                q_embeds.push(q_embed);
2118                k_embeds.push(k_embed);
2119            }
2120            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
2121        }
2122    }
2123}
2124
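/// Activation functions used by the model configs in this crate. Variants are
/// deserialized from lowercase names (with aliases such as `gelu_new` and
/// `gelu_pytorch_tanh`) and applied elementwise via the `Module` impl below.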
2125#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
2126#[serde(rename_all = "lowercase")]
2127pub enum Activation {
2128    #[default]
2129    #[serde(alias = "gelu")]
2130    Gelu,
2131    #[serde(alias = "gelu_new")]
2132    NewGelu,
2133    Relu,
2134    Relu2,
2135    Relu6,
2136    Silu,
2137    Sigmoid,
2138    HardSigmoid,
2139    Swiglu,
2140    Swish,
2141    HardSwish,
2142    Elu(f64),
2143    LeakyRelu(f64),
2144    #[serde(alias = "gelu_pytorch_tanh")]
2145    GeluPytorchTanh,
2146    QuickGelu,
2147}
2148
2149impl Module for Activation {
2150    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2151        match self {
2152            Self::Gelu => xs.gelu_erf(),
2153            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
2154            Self::NewGelu => xs.gelu(),
2155            Self::Relu => xs.relu(),
2156            Self::Relu2 => xs.relu()?.sqr(),
2157            Self::Relu6 => xs.clamp(0f32, 6f32),
2158            Self::Silu => xs.silu(),
2159            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
2160            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
2161            Self::Swiglu => candle_nn::ops::swiglu(xs),
2162            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
2163            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
2164            &Self::Elu(alpha) => xs.elu(alpha),
2165            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
2166            Self::GeluPytorchTanh => xs.gelu(),
2167            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
2168        }
2169    }
2170}
2171
2172impl TryInto<candle_nn::Activation> for Activation {
2173    type Error = candle_core::Error;
2174
2175    fn try_into(self) -> Result<candle_nn::Activation> {
2176        match self {
2177            Self::Gelu => Ok(candle_nn::Activation::Gelu),
2178            Self::Relu => Ok(candle_nn::Activation::Relu),
2179            Self::Silu => Ok(candle_nn::Activation::Silu),
2180            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
2181            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
2182            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
2183            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
2184            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
2185            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
2186            Self::Swish => Ok(candle_nn::Activation::Swish),
2187            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
2188            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
2189            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
2190            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
2191            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
2192        }
2193    }
2194}
2195
2196#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2197pub struct Conv3dConfig {
2198    pub padding: usize,
2199    pub stride: usize,
2200    pub dilation: usize,
2201    pub groups: usize,
2202}
2203
2204impl Default for Conv3dConfig {
2205    fn default() -> Self {
2206        Self {
2207            padding: 0,
2208            stride: 1,
2209            dilation: 1,
2210            groups: 1,
2211        }
2212    }
2213}
2214
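/// Bias-free 3D convolution expressed as two 2D convolutions: the weight is split along
/// the temporal (depth) dimension, so a temporal extent of 2 is assumed for both the
/// kernel and the input (see the linked PyTorch issue below).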
2215pub struct Conv3dNoBias {
2216    conv2d_1: Conv2d,
2217    conv2d_2: Conv2d,
2218}
2219
2220impl Conv3dNoBias {
2221    pub fn new(
2222        in_channels: usize,
2223        out_channels: usize,
2224        kernel_sizes: [usize; 3],
2225        cfg: Conv3dConfig,
2226        vb: ShardedVarBuilder,
2227    ) -> Result<Self> {
2228        let ws = vb.get(
2229            (
2230                out_channels,
2231                in_channels / cfg.groups,
2232                kernel_sizes[0],
2233                kernel_sizes[1],
2234                kernel_sizes[2],
2235            ),
2236            "weight",
2237        )?;
2238
2239        // Split on temporal dimension
2240        // https://github.com/pytorch/pytorch/issues/139066
2241
2242        let w1 = ws.i((.., .., 0, .., ..))?;
2243        let w2 = ws.i((.., .., 1, .., ..))?;
2244
2245        let cfg = Conv2dConfig {
2246            padding: cfg.padding,
2247            stride: cfg.stride,
2248            dilation: cfg.dilation,
2249            groups: cfg.groups,
2250        };
2251
2252        Ok(Self {
2253            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
2254            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
2255        })
2256    }
2257}
2258
2259impl Module for Conv3dNoBias {
2260    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2261        let xs1 = xs.i((.., .., 0, .., ..))?;
2262        let xs2 = xs.i((.., .., 1, .., ..))?;
2263
2264        (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
2265    }
2266}
2267
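/// Small helpers for infinity checks: `is_inf` marks elements equal to +inf, and `any`
/// reports whether any element of the resulting (boolean) tensor is non-zero.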
2268pub trait TensorInfExtend {
2269    fn is_inf(&self) -> Result<Self>
2270    where
2271        Self: Sized;
2272    fn any(&self) -> Result<bool>;
2273}
2274
2275impl TensorInfExtend for Tensor {
2276    fn is_inf(&self) -> Result<Self> {
2277        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
2278    }
2279
2280    fn any(&self) -> Result<bool> {
2281        let sum = self.sum_all()?;
2282        match self.dtype() {
2283            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
2284            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
2285            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
2286            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
2287            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
2288            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
2289            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
2290            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
2291            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
2292            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
2293        }
2294    }
2295}
2296
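/// Clamp `xs` to a symmetric range slightly below the maximum of its dtype (with an
/// extra margin when +inf values are present) so that later f16 computations do not
/// overflow.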
2297pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
2298    let mut max = match xs.dtype() {
2299        DType::U8 => u8::MAX as f32 - 1000.,
2300        DType::U32 => u32::MAX as f32 - 1000.,
2301        DType::I16 => i16::MAX as f32 - 1000.,
2302        DType::I32 => i32::MAX as f32 - 1000.,
2303        DType::I64 => i64::MAX as f32 - 1000.,
2304        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
2305        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
2306        DType::F32 => f32::MAX - 1000.,
2307        DType::F64 => f64::MAX as f32 - 1000.,
2308        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
2309    };
2310    if xs.is_inf()?.any()? {
2311        max -= 1000.;
2312    }
2313    xs.clamp(-max, max)
2314}
2315
2316pub struct FloatInfo {
2317    /// Minimum representable value.
2318    pub min: f64,
2319    /// Maximum representable value.
2320    pub max: f64,
2321    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
2322    pub eps: f64,
2323    pub dtype: DType,
2324}
2325
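/// Float limits per `DType`, analogous to `torch.finfo`; errors for non-float types.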
2326pub trait GetFloatInfo {
2327    fn finfo(&self) -> Result<FloatInfo>;
2328}
2329
2330impl GetFloatInfo for DType {
2331    fn finfo(&self) -> Result<FloatInfo> {
2332        let finfo = match self {
2333            Self::BF16 => FloatInfo {
2334                min: bf16::MIN.to_f64(),
2335                max: bf16::MAX.to_f64(),
2336                eps: bf16::EPSILON.to_f64(),
2337                dtype: DType::BF16,
2338            },
2339            Self::F16 => FloatInfo {
2340                min: f16::MIN.to_f64(),
2341                max: f16::MAX.to_f64(),
2342                eps: f16::EPSILON.to_f64(),
2343                dtype: DType::F16,
2344            },
2345            Self::F32 => FloatInfo {
2346                min: f32::MIN as f64,
2347                max: f32::MAX as f64,
2348                eps: f32::EPSILON as f64,
2349                dtype: DType::F32,
2350            },
2351            Self::F64 => FloatInfo {
2352                min: f64::MIN,
2353                max: f64::MAX,
2354                eps: f64::EPSILON,
2355                dtype: DType::F64,
2356            },
2357            Self::F8E4M3 => FloatInfo {
2358                min: F8E4M3::MIN.to_f64(),
2359                max: F8E4M3::MAX.to_f64(),
2360                eps: F8E4M3::EPSILON.to_f64(),
2361                dtype: DType::F8E4M3,
2362            },
2363            other => {
2364                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
2365            }
2366        };
2367        Ok(finfo)
2368    }
2369}
2370
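/// Gated MLP computing `down(act(gate(x)) * up(x))`, built from tensor-parallel column
/// and row layers with optional quantization.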
2371#[derive(Clone)]
2372pub struct Mlp {
2373    pub gate: Arc<dyn QuantMethod>,
2374    pub up: Arc<dyn QuantMethod>,
2375    pub down: Arc<dyn QuantMethod>,
2376    act: Activation,
2377    params: Vec<usize>,
2378}
2379
2380impl Mlp {
2381    pub fn new(
2382        vb: ShardedVarBuilder,
2383        hidden_size: usize,
2384        intermediate_size: usize,
2385        quantization_config: &Option<QuantizedConfig>,
2386        hidden_act: Activation,
2387        comm: &Arc<mistralrs_quant::Comm>,
2388    ) -> Result<Self> {
2389        Ok(Self {
2390            gate: ColumnParallelLayer::new(
2391                hidden_size,
2392                intermediate_size,
2393                quantization_config,
2394                false,
2395                comm,
2396                vb.pp("gate_proj"),
2397            )?,
2398            up: ColumnParallelLayer::new(
2399                hidden_size,
2400                intermediate_size,
2401                quantization_config,
2402                false,
2403                comm,
2404                vb.pp("up_proj"),
2405            )?,
2406            down: RowParallelLayer::new(
2407                intermediate_size,
2408                hidden_size,
2409                quantization_config,
2410                false,
2411                comm,
2412                vb.pp("down_proj"),
2413            )?,
2414            act: hidden_act,
2415            params: vec![hidden_size, intermediate_size],
2416        })
2417    }
2418
2419    pub fn new_merged(
2420        vb: ShardedVarBuilder,
2421        hidden_size: usize,
2422        intermediate_size: usize,
2423        chunks: usize,
2424        quantization_config: &Option<QuantizedConfig>,
2425        hidden_act: Activation,
2426        comm: &Arc<mistralrs_quant::Comm>,
2427    ) -> Result<Self> {
2428        assert_eq!(chunks, 2, "Only gate_up_proj merge is supported!");
2429        let gate_up_projs = ColumnParallelLayer::new_merged(
2430            hidden_size,
2431            intermediate_size * 2,
2432            2,
2433            quantization_config,
2434            false,
2435            comm,
2436            vb.pp("gate_up_proj"),
2437        )?;
2438
2439        Ok(Self {
2440            gate: gate_up_projs[0].to_owned(),
2441            up: gate_up_projs[1].to_owned(),
2442            down: RowParallelLayer::new(
2443                intermediate_size,
2444                hidden_size,
2445                quantization_config,
2446                false,
2447                comm,
2448                vb.pp("down_proj"),
2449            )?,
2450            act: hidden_act,
2451            params: vec![hidden_size, intermediate_size],
2452        })
2453    }
2454
2455    pub fn replicate(
2456        params: &[usize],
2457        vb: ShardedVarBuilder,
2458        act: Activation,
2459        comm: &Arc<mistralrs_quant::Comm>,
2460    ) -> Result<Self> {
2461        Self::new(vb, params[0], params[1], &None, act, comm)
2462    }
2463
2464    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2465        let original_dtype = xs.dtype();
2466        let mut xs = xs.clone();
2467        if let Some(t) = self.gate.quantized_act_type() {
2468            xs = xs.to_dtype(t)?;
2469        }
2470        let lhs = self.gate.forward(&xs)?;
2471        let rhs = self.up.forward(&xs)?;
2472        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
2473            &lhs,
2474            &rhs,
2475            self.act.try_into()?,
2476        )?)?;
2477        if self.gate.quantized_act_type().is_some() {
2478            res = res.to_dtype(original_dtype)?;
2479        }
2480        Ok(res)
2481    }
2482}
2483
2484impl AnyMoeTrainableLayer for Mlp {}
2485
2486impl MlpLayer for Mlp {
2487    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2488        let original_dtype = xs.dtype();
2489        let mut xs = xs.clone();
2490        if let Some(t) = self.gate.quantized_act_type() {
2491            xs = xs.to_dtype(t)?;
2492        }
2493        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2494        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
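        // Fuse the elementwise multiply with the activation only for Gelu/Silu/Relu;
        // other activations fall back to the unfused `act(gate) * up` path below.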
2495        let mut res = if matches!(
2496            self.act,
2497            Activation::Gelu | Activation::Silu | Activation::Relu
2498        ) {
2499            MatMul.qmethod_matmul(
2500                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
2501                &*self.down,
2502            )?
2503        } else {
2504            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2505        };
2506        if self.gate.quantized_act_type().is_some() {
2507            res = res.to_dtype(original_dtype)?;
2508        }
2509        Ok(res)
2510    }
2511    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2512        vec![&mut self.gate, &mut self.up, &mut self.down]
2513    }
2514    fn clone(&self) -> Box<dyn MlpLayer> {
2515        Box::new(Clone::clone(self))
2516    }
2517    fn get_params(&self) -> &[usize] {
2518        &self.params
2519    }
2520    fn hidden_act(&self) -> Activation {
2521        self.act
2522    }
2523    // Deltas are applied in the order: gate, up, down.
2524    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2525        let gate = if let Some(ref delta) = deltas[0] {
2526            self.gate.add_delta_w(delta)?
2527        } else {
2528            self.gate.clone()
2529        };
2530        let up = if let Some(ref delta) = deltas[1] {
2531            self.up.add_delta_w(delta)?
2532        } else {
2533            self.up.clone()
2534        };
2535        let down = if let Some(ref delta) = deltas[2] {
2536            self.down.add_delta_w(delta)?
2537        } else {
2538            self.down.clone()
2539        };
2540
2541        Ok(Box::new(Self {
2542            gate,
2543            up,
2544            down,
2545            act: self.act,
2546            params: self.params.clone(),
2547        }))
2548    }
2549
2550    fn dtype_device(&self) -> (DType, Device) {
2551        self.gate.dtype_and_device()
2552    }
2553}
2554
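/// 2D average pooling with a square kernel and a fixed stride, e.g.
/// `AvgPool2d::new(2, 2).forward(&xs)` halves the spatial dimensions of an
/// (N, C, H, W) tensor when H and W are even.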
2555pub struct AvgPool2d {
2556    kernel_size: usize,
2557    stride: usize,
2558}
2559
2560impl AvgPool2d {
2561    pub fn new(kernel_size: usize, stride: usize) -> Self {
2562        Self {
2563            kernel_size,
2564            stride,
2565        }
2566    }
2567
2568    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2569        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2570    }
2571}
2572
2573/// Applies 2D reflection padding to a tensor of shape (N, C, H, W).
2574///
2575/// The `padding` argument is a 4-tuple (pad_left, pad_right, pad_top, pad_bottom).
2576/// For left padding, it reflects the values from column 1 up to pad_left (in reverse order);
2577/// for right padding, it reflects from the second-to-last column backwards, and similarly for
2578/// vertical (height) padding.
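///
/// For example, with `W = 4` columns `[c0, c1, c2, c3]`, `pad_left = 2` and
/// `pad_right = 1`, each row becomes `[c2, c1, c0, c1, c2, c3, c2]`.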
2579pub struct ReflectionPad2d {
2580    padding: (usize, usize, usize, usize),
2581}
2582
2583impl ReflectionPad2d {
2584    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2585        Self { padding }
2586    }
2587}
2588
2589impl Module for ReflectionPad2d {
2590    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2591        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2592
2593        let (_n, _c, h, w) = xs.dims4()?;
2594
2595        // --- Horizontal Padding (along width, axis = 3) ---
2596        // For left padding, we reflect columns 1..=pad_left (in reverse order).
2597        let left_pad = if pad_left > 0 {
2598            // Create indices: [pad_left, pad_left-1, ..., 1]
2599            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2600            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2601        } else {
2602            None
2603        };
2604
2605        // For right padding, we reflect from the right side (excluding the last column).
2606        let right_pad = if pad_right > 0 {
2607            // Generate indices [w-2, w-3, ..., w-1-pad_right] (reflect, excluding the last column).
2608            let start = w as i64 - 2;
2609            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2610            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2611        } else {
2612            None
2613        };
2614
2615        // Concatenate horizontally (along width, dim=3)
2616        let x_padded_width = match (left_pad, right_pad) {
2617            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2618            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2619            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2620            (None, None) => xs.clone(),
2621        };
2622
2623        // --- Vertical Padding (along height, axis = 2) ---
2624        // For top padding, reflect rows 1..=pad_top (in reverse order)
2625        let top_pad = if pad_top > 0 {
2626            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2627            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2628        } else {
2629            None
2630        };
2631
2632        // For bottom padding, reflect from the bottom (excluding the last row)
2633        let bottom_pad = if pad_bottom > 0 {
2634            let start = h as i64 - 2;
2635            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2636            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2637        } else {
2638            None
2639        };
2640
2641        // Concatenate vertically (along height, dim=2)
2642        let x_padded = match (top_pad, bottom_pad) {
2643            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2644            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2645            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2646            (None, None) => x_padded_width,
2647        };
2648
2649        Ok(x_padded)
2650    }
2651}
2652
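/// Embedding whose output is multiplied by a fixed `scale` (used where model
/// architectures scale token embeddings by a constant, e.g. sqrt(hidden_size)).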
2653pub struct ScaledEmbedding {
2654    scale: f64,
2655    pub embedding: Tensor,
2656}
2657
2658impl ScaledEmbedding {
2659    pub fn new(scale: f64, embedding: Embedding) -> Self {
2660        Self {
2661            scale,
2662            embedding: embedding.embeddings().clone(),
2663        }
2664    }
2665
2666    pub fn embeddings(&self) -> &Tensor {
2667        &self.embedding
2668    }
2669}
2670
2671impl Module for ScaledEmbedding {
2672    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2673        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
2674        xs.apply(&embedding)? * self.scale
2675    }
2676}