mistralrs_core/
layers.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};

use candle_core::{
    quantized::{QMatMul, QTensor},
    Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{
    BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
    LayerNorm, LayerNormConfig, Linear, Module,
};
use float8::F8E4M3;
use half::{bf16, f16};
use mistralrs_quant::{
    AfqLayer, ColumnParallelLayer, Convolution, QuantMethod, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};

pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::repeat_kv;
use crate::{
    amoe::{AnyMoeTrainableLayer, MlpLayer},
    gguf::Content,
    models::{llama, smollm3},
    ops::SplitOp,
    vision_models::{
        gemma3::config::Gemma3TextConfig,
        gemma3n::config::Gemma3nTextConfig,
        llama4,
        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
        phi4::Phi4MMConfig,
    },
};

pub use mistralrs_quant::MatMul;

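/// Load an [`Embedding`] whose weight has shape `(in_size, out_size)`. When the model is
/// AFQ-quantized, the embedding weight is stored quantized as well, so it is loaded through
/// [`AfqLayer`] and dequantized back to a dense tensor.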
pub fn embedding(
    in_size: usize,
    out_size: usize,
    vb: ShardedVarBuilder,
    config: &Option<QuantizedConfig>,
) -> Result<Embedding> {
    // AFQ quantization also applies to the embedding weights, so load them via AfqLayer and dequantize.
    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
        let afq_layer =
            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
        afq_layer.dequantize_w()?
    } else {
        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
    };
    Ok(Embedding::new(embeddings, out_size))
}

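/// Load a [`LayerNorm`]; the bias is only loaded when `config.affine` is set.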
pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get(size, "weight")?;
    if config.affine {
        let bias = vb.get(size, "bias")?;
        Ok(LayerNorm::new(weight, bias, config.eps))
    } else {
        Ok(LayerNorm::new_no_bias(weight, config.eps))
    }
}

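/// Load a [`BatchNorm`] together with its running statistics; the affine `weight`/`bias` are
/// only loaded when `config.affine` is set.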
pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<BatchNorm> {
    let config = config.into();
    if config.eps < 0. {
        candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    let running_mean = vb.get(num_features, "running_mean")?;
    let running_var = vb.get(num_features, "running_var")?;

    if config.affine {
        let weight = vb.get(num_features, "weight")?;
        let bias = vb.get(num_features, "bias")?;
        BatchNorm::new(
            num_features,
            running_mean,
            running_var,
            weight,
            bias,
            config.eps,
        )
    } else {
        BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
    }
}

pub fn group_norm(
    num_groups: usize,
    num_channels: usize,
    eps: f64,
    vb: ShardedVarBuilder,
) -> Result<GroupNorm> {
    let weight = vb.get(num_channels, "weight")?;
    let bias = vb.get(num_channels, "bias")?;
    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
}

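/// Load a [`Conv2d`] with bias; the weight is expected with shape
/// `(out_channels, in_channels / groups, kernel_size, kernel_size)`. `conv2d_no_bias` below is
/// the bias-free variant, and the `conv1d` pair mirrors this for 1D convolutions.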
pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}

pub fn conv2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}

pub fn conv1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

pub fn conv1d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    Ok(Conv1d::new(ws, None, cfg))
}

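/// Load an unquantized [`Linear`] with bias; `weight` has shape `(out_dim, in_dim)` and `bias`
/// has shape `(out_dim,)`. `linear_no_bias` and `linear_b` below are the bias-free and
/// switchable variants.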
pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    let bs = vb.get(out_dim, "bias")?;
    Ok(Linear::new(ws, Some(bs)))
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    Ok(Linear::new(ws, None))
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: ShardedVarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}

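/// RMS normalization over the last dimension:
/// `y = x / sqrt(mean(x^2) + eps) * weight`.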
#[derive(Debug, Clone)]
pub struct RmsNorm {
    eps: f64,
    weight: Tensor,
}

impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        Ok(Self { eps, weight: w })
    }

    /// Gemma uses weight + 1.0
    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        let w = (w + 1.0)?;
        Ok(Self { eps, weight: w })
    }

    /// Gemma 3n uses the raw weight (no +1.0); when `with_scale` is false the weight is all ones.
    pub fn new_gemma_3n(
        size: usize,
        eps: f64,
        with_scale: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let w = if with_scale {
            vb.get(size, "weight")?
        } else {
            Tensor::ones(size, vb.dtype(), vb.device())?
        };
        Ok(Self { eps, weight: w })
    }

    /// Gemma uses weight + 1.0. Undo for UQFF generation.
    pub fn undo_gemma(&self) -> Result<Self> {
        Ok(Self {
            eps: self.eps,
            weight: (&self.weight - 1.0)?,
        })
    }

    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
        Ok(Self { eps, weight: w })
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

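/// RMS normalization that computes its statistics in `f32` regardless of the input dtype,
/// casting back to the original dtype before the weight is applied.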
#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    w: Tensor,
    eps: f64,
}

impl F32RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }

    pub fn weight(&self) -> &Tensor {
        &self.w
    }
}

impl Module for F32RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}

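/// RMS normalization whose weight is dequantized from a [`QTensor`] at construction time.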
#[derive(Debug, Clone)]
pub struct QRmsNorm {
    eps: f64,
    weight: Tensor,
}

impl QRmsNorm {
    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
        let scale = scale.dequantize(&scale.device())?;
        Ok(Self {
            eps: eps as f64,
            weight: scale,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

/// RoPE supporting LongRope
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    #[serde(alias = "yarn")]
    Yarn,
    #[serde(alias = "dynamic")]
    Dynamic,
    #[serde(alias = "linear")]
    Linear,
}

impl FromStr for ScaledRopeType {
    type Err = candle_core::Error;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "su" | "longrope" => Ok(Self::Su),
            "yarn" => Ok(Self::Yarn),
            "linear" => Ok(Self::Linear),
            "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Unknown scaled RoPE type: expected `su`/`longrope`, `yarn`, `linear`, or `dynamic`.".to_string(),
            )),
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
    Classic {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    Scaled {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
    },
}

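/// Common RoPE configuration for the Phi family. A `partial_rotary_factor` below 1.0 applies the
/// rotation to only that fraction of the head dimension; the remainder is passed through.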
pub struct PhiRopeConfig {
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub original_max_position_embeddings: usize,
    pub rope_theta: f64,
    pub head_dim: usize,
    pub partial_rotary_factor: Option<f64>,
}

impl PhiRotaryEmbedding {
    fn new_classic_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        // Calculate scale
        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            match scaling_type {
                ScaledRopeType::Su => {
                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
                }
                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
            }
        };

        // Calculate inv freqs for short, long
        let inv_freq_long = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_short = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq_long.len();

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;

        // Calculate sin,cos for long
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
        let freqs_long = t.matmul(&inv_freq_long)?;
        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        // Calculate sin,cos for short
        let inv_freq_short =
            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
        let freqs_short = t.matmul(&inv_freq_short)?;
        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        Ok(Self {
            short_cos,
            short_sin,
            long_cos: Some(long_cos),
            long_sin: Some(long_sin),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        if !matches!(scaling_type, ScaledRopeType::Su) {
            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
        }

        if short_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
                short_factor.len(),
                dim / 2
            );
        }
        if long_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
                long_factor.len(),
                dim / 2
            );
        }

        // Short cos/sin
        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;

        // Long cos/sin
        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
        let cfg: PhiRopeConfig = cfg.into();

        match &cfg.rope_scaling {
            Some(PhiRopeScalingConfig::Classic {
                short_factor,
                long_factor,
                scaling_type,
            }) => {
                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
            }

            Some(PhiRopeScalingConfig::Scaled {
                short_factor,
                long_factor,
                scaling_type,
                long_mscale,
                short_mscale,
            }) => Self::new_scaled(
                short_factor,
                long_factor,
                scaling_type,
                *long_mscale,
                *short_mscale,
                &cfg,
                dtype,
                dev,
            ),

            None => Self::new_unscaled(&cfg, dtype, dev),
        }
    }

    /// Returns (sin, cos) taking into account LongRope
    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        let rot_dim = cos.dim(D::Minus1)? * 2;

        // Case for Phi 3 / Phi 4 mini
        if rot_dim != q.dim(D::Minus1)? {
            let rot_dim = cos.dim(D::Minus1)? * 2;
            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
                (q_embed, k_embed)
            } else {
                let mut q_embeds = Vec::new();
                let mut k_embeds = Vec::new();
                for (i, offset) in seqlen_offsets.iter().enumerate() {
                    let cos = cos.narrow(0, *offset, seq_len)?;
                    let sin = sin.narrow(0, *offset, seq_len)?;
                    let q_embed = candle_nn::rotary_emb::rope(
                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    let k_embed = candle_nn::rotary_emb::rope(
                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    q_embeds.push(q_embed);
                    k_embeds.push(k_embed);
                }
                let q_rot = Tensor::cat(&q_embeds, 0)?;
                let k_rot = Tensor::cat(&k_embeds, 0)?;
                (q_rot, k_rot)
            };

            Ok((
                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
            ))
        } else if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

/// RoPE for Llama3
#[derive(Debug, Clone)]
pub struct Llama3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[serde(rename = "linear")]
    Linear,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: Option<f32>,
    pub high_freq_factor: Option<f32>,
    pub original_max_position_embeddings: Option<usize>,
    pub rope_type: Llama3RopeType,
}

fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

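// Llama 3 style RoPE rescaling: inverse frequencies whose wavelength exceeds
// `original_max_position_embeddings / low_freq_factor` are divided by `factor`, those with a
// wavelength below `original_max_position_embeddings / high_freq_factor` are left unchanged, and
// the band in between is blended linearly via `smooth`.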
// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
impl Llama3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &llama::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_llama4(
        dtype: DType,
        cfg: &llama4::TextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_mllama3(
        dtype: DType,
        cfg: &MLlamaTextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Llama3,
                original_max_position_embeddings,
                factor,
                attention_factor: _,
                beta_fast: _,
                beta_slow: _,
                short_factor: _,
                long_factor: _,
                low_freq_factor,
                high_freq_factor,
            }) => {
                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
                let low_freq_factor = low_freq_factor
                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
                let high_freq_factor = high_freq_factor
                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;

                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;

                let inv_freq = (0..head_dim)
                    .step_by(2)
                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / factor
                        } else {
                            let smooth = (*original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(MLlamaRopeScaling {
                rope_type: other, ..
            }) => {
                candle_core::bail!(
                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
                )
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

/// RoPE for SmolLm3
#[derive(Debug, Clone)]
pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum SmolLm3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[serde(rename = "linear")]
    Linear,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SmolLm3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: Option<f32>,
    pub high_freq_factor: Option<f32>,
    pub original_max_position_embeddings: Option<usize>,
    pub rope_type: SmolLm3RopeType,
}

fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

impl SmolLm3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &smollm3::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq_smollm3(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(SmolLm3RopeConfig {
                rope_type: SmolLm3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
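/// Multimodal rotary embedding (M-RoPE) used by Qwen2-VL: `position_ids` carries three rows of
/// positions per token (temporal/height/width in the reference implementation linked above), and
/// `mrope_section` describes how the rotary dimensions are partitioned between those rows when
/// assembling the cos/sin tables.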
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    /// (cos, sin)
    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct Qwen2_5VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2_5VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    /// (cos, sin)
    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct Qwen3VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen3VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    /// (cos, sin)
    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

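/// RoPE for the DeepSeek-V2 family, built over `qk_rope_head_dim` dimensions, with optional YaRN
/// scaling of the inverse frequencies.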
#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}

pub struct DeepSeekV2RopeConfig {
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub max_position_embeddings: usize,
    pub rope_theta: f32,
    pub qk_rope_head_dim: usize,
}

impl DeepSeekV2RotaryEmbedding {
    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.qk_rope_head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

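    // `yarn_find_correction_dim` inverts the RoPE wavelength formula: it returns the (fractional)
    // position in the inverse-frequency vector whose frequency completes `num_rot` rotations over
    // `max_position_embeddings` positions; `yarn_find_correction_range` then clamps that to a
    // valid [low, high] band used by the linear ramp mask below.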
1407    fn yarn_find_correction_dim(
1408        num_rot: f32,
1409        dim: usize,
1410        base: f32,
1411        max_position_embeddings: usize,
1412    ) -> f32 {
1413        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1414            / (2. * base.ln())
1415    }
1416
1417    fn yarn_find_correction_range(
1418        low_rot: f32,
1419        high_rot: f32,
1420        dim: usize,
1421        base: f32,
1422        max_position_embeddings: usize,
1423    ) -> (f32, f32) {
1424        let low =
1425            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1426        let high =
1427            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1428        (low.max(0.), high.min(dim as f32 - 1.))
1429    }
1430
1431    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1432        if min == max {
1433            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
1434            max += 0.001;
1435        }
1436        let linear_func =
1437            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1438        linear_func.clamp(0., 1)
1439    }
1440
1441    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1442        if scale <= 1. {
1443            return 1.;
1444        }
1445        0.1 * mscale * scale.ln() + 1.
1446    }
1447
1448    #[allow(clippy::too_many_arguments)]
1449    fn new_yarn(
1450        cfg: &DeepSeekV2RopeConfig,
1451        dtype: DType,
1452        dev: &Device,
1453        original_max_position_embeddings: usize,
1454        beta_fast: f32,
1455        beta_slow: f32,
1456        factor: f32,
1457        mscale: f32,
1458        mscale_all_dim: f32,
1459    ) -> Result<Self> {
1460        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1461            .step_by(2)
1462            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1463            .collect();
1464        let freq_extra_len = freq_extra.len();
1465        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1466        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1467            .step_by(2)
1468            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1469            .collect();
1470        let freq_inter_len = freq_inter.len();
1471        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1472
1473        let (low, high) = Self::yarn_find_correction_range(
1474            beta_fast,
1475            beta_slow,
1476            cfg.qk_rope_head_dim,
1477            cfg.rope_theta,
1478            original_max_position_embeddings,
1479        );
1480        let inv_freq_mask =
1481            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1482        let inv_freq = freq_inter
1483            .broadcast_mul(&(1. - &inv_freq_mask)?)?
1484            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1485
1486        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1487            .to_dtype(DType::F32)?
1488            .reshape((cfg.max_position_embeddings, 1))?;
1489        let freqs = t.matmul(&inv_freq)?;
1490
1491        let mscale =
1492            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1493        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1494        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1495
1496        Ok(Self { sin, cos })
1497    }
1498
1499    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1500        match &cfg.rope_scaling {
1501            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1502                scaling_type: _,
1503                factor: _,
1504            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1505            Some(DeepSeekV2RopeScaling::Yarn {
1506                original_max_position_embeddings,
1507                beta_fast,
1508                beta_slow,
1509                factor,
1510                mscale,
1511                mscale_all_dim,
1512                scaling_type: _,
1513            }) => Self::new_yarn(
1514                cfg,
1515                dtype,
1516                dev,
1517                *original_max_position_embeddings,
1518                *beta_fast,
1519                *beta_slow,
1520                *factor,
1521                *mscale,
1522                *mscale_all_dim,
1523            ),
1524            None => Self::new_unscaled(cfg, dtype, dev),
1525        }
1526    }
1527
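    /// Applies the rotary embedding to `q`/`k` of shape
    /// `(batch, heads, seq_len, rope_dim)`, with one KV-cache offset per sequence.
    /// A minimal usage sketch (shapes and offsets are illustrative assumptions):
    ///
    /// ```ignore
    /// // Two sequences in the batch, one already holding 32 cached tokens.
    /// let (q, k) = rope.forward(&q, &k, &[0, 32])?;
    /// ```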
1528    pub fn forward(
1529        &self,
1530        q: &Tensor,
1531        k: &Tensor,
1532        seqlen_offsets: &[usize],
1533    ) -> Result<(Tensor, Tensor)> {
1534        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1535
1536        if seqlen_offsets.len() == 1 {
1537            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1538            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1539            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1540            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1541            Ok((q_embed, k_embed))
1542        } else {
1543            let mut q_embeds = Vec::new();
1544            let mut k_embeds = Vec::new();
1545            for (i, offset) in seqlen_offsets.iter().enumerate() {
1546                let cos = self.cos.narrow(0, *offset, seq_len)?;
1547                let sin = self.sin.narrow(0, *offset, seq_len)?;
1548                let q_embed = candle_nn::rotary_emb::rope_i(
1549                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
1550                    &cos,
1551                    &sin,
1552                )?;
1553                let k_embed = candle_nn::rotary_emb::rope_i(
1554                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
1555                    &cos,
1556                    &sin,
1557                )?;
1558                q_embeds.push(q_embed);
1559                k_embeds.push(k_embed);
1560            }
1561            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1562        }
1563    }
1564}
1565
1566#[derive(Debug, Clone)]
1567pub struct Phi4MMRotaryEmbedding {
1568    short_sin: Tensor,
1569    short_cos: Tensor,
1570    long_cos: Option<Tensor>,
1571    long_sin: Option<Tensor>,
1572    original_max_position_embeddings: usize,
1573}
1574
1575#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1576#[serde(rename_all = "lowercase")]
1577pub enum Phi4MMScaledRopeType {
1578    #[serde(alias = "longrope")]
1579    LongRope,
1580    #[default]
1581    Default,
1582}
1583
1584#[derive(Debug, Clone, Deserialize, Serialize)]
1585pub struct Phi4MMRopeScalingConfig {
1586    short_factor: Option<Vec<f64>>,
1587    long_factor: Option<Vec<f64>>,
1588    #[serde(rename = "type")]
1589    scaling_type: Phi4MMScaledRopeType,
1590}
1591
1592impl Phi4MMRotaryEmbedding {
1593    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1594        let max_seq_len = cfg.max_position_embeddings;
1595        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1596
1597        let inv_freq: Vec<_> = (0..dim)
1598            .step_by(2)
1599            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1600            .collect();
1601        let inv_freq_len = inv_freq.len();
1602        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1603        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1604            .to_dtype(DType::F32)?
1605            .reshape((max_seq_len, 1))?;
1606        let freqs = t.matmul(&inv_freq)?;
1607        let sin = freqs.sin()?.to_dtype(dtype)?;
1608        let cos = freqs.cos()?.to_dtype(dtype)?;
1609        Ok(Self {
1610            short_cos: cos,
1611            short_sin: sin,
1612            long_cos: None,
1613            long_sin: None,
1614            original_max_position_embeddings: cfg.original_max_position_embeddings,
1615        })
1616    }
1617
1618    #[allow(clippy::too_many_arguments)]
1619    fn new_longrope(
1620        short_factor: &[f64],
1621        long_factor: &[f64],
1622        cfg: &Phi4MMConfig,
1623        dtype: DType,
1624        dev: &Device,
1625    ) -> Result<Self> {
1626        let max_seq_len = cfg.max_position_embeddings;
1627        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1628
1629        // Calculate scale
1630        let scale =
1631            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1632        let scaling_factor = if scale <= 1.0 {
1633            1.0
1634        } else {
1635            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1636        };
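        // Hedged worked example (illustrative numbers only): with
        // max_position_embeddings = 131072 and original_max_position_embeddings = 4096,
        // scale = 32 and scaling_factor = sqrt(1 + ln(32) / ln(4096)) ≈ 1.19.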
1637
1638        // Short cos/sin
1639        let inv_freq_short: Vec<_> = (0..dim)
1640            .step_by(2)
1641            .enumerate()
1642            .map(|(k, i)| {
1643                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1644            })
1645            .collect();
1646        let inv_freq_len_short = inv_freq_short.len();
1647        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1648        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1649            .to_dtype(DType::F32)?
1650            .reshape((max_seq_len, 1))?;
1651        let freqs_short = t_short.matmul(&inv_freq_short)?;
1652        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1653        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1654
1655        // Long cos/sin
1656        let inv_freq_long: Vec<_> = (0..dim)
1657            .step_by(2)
1658            .enumerate()
1659            .map(|(k, i)| {
1660                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1661            })
1662            .collect();
1663        let inv_freq_len_long = inv_freq_long.len();
1664        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1665        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1666            .to_dtype(DType::F32)?
1667            .reshape((max_seq_len, 1))?;
1668        let freqs_long = t_long.matmul(&inv_freq_long)?;
1669        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1670        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1671
1672        Ok(Self {
1673            short_cos: cos_short,
1674            short_sin: sin_short,
1675            long_cos: Some(cos_long),
1676            long_sin: Some(sin_long),
1677            original_max_position_embeddings: cfg.original_max_position_embeddings,
1678        })
1679    }
1680
1681    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1682        match &cfg.rope_scaling {
1683            Some(Phi4MMRopeScalingConfig {
1684                scaling_type: Phi4MMScaledRopeType::LongRope,
1685                short_factor: Some(short_factor),
1686                long_factor: Some(long_factor),
1687            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1688
1689            _ => Self::new_unscaled(cfg, dtype, dev),
1690        }
1691    }
1692
1693    /// Returns (sin, cos) taking into account LongRope
1694    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1695        if self.long_cos.is_none() {
1696            return (&self.short_sin, &self.short_cos);
1697        }
1698        let seq_len = position_ids.iter().max().unwrap() + 1;
1699        if seq_len > self.original_max_position_embeddings {
1700            (
1701                self.long_sin.as_ref().unwrap(),
1702                self.long_cos.as_ref().unwrap(),
1703            )
1704        } else {
1705            (&self.short_sin, &self.short_cos)
1706        }
1707    }
1708
1709    pub fn forward(
1710        &self,
1711        q: &Tensor,
1712        k: &Tensor,
1713        seqlen_offsets: &[usize],
1714        position_ids: &[usize],
1715    ) -> Result<(Tensor, Tensor)> {
1716        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1717        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1718
1719        let rot_dim = cos.dim(D::Minus1)? * 2;
1720        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1721        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1722        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1723        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1724
1725        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1726            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1727            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1728            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1729            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1730            (q_embed, k_embed)
1731        } else {
1732            let mut q_embeds = Vec::new();
1733            let mut k_embeds = Vec::new();
1734            for (i, offset) in seqlen_offsets.iter().enumerate() {
1735                let cos = cos.narrow(0, *offset, seq_len)?;
1736                let sin = sin.narrow(0, *offset, seq_len)?;
1737                let q_embed = candle_nn::rotary_emb::rope(
1738                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1739                    &cos,
1740                    &sin,
1741                )?;
1742                let k_embed = candle_nn::rotary_emb::rope(
1743                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1744                    &cos,
1745                    &sin,
1746                )?;
1747                q_embeds.push(q_embed);
1748                k_embeds.push(k_embed);
1749            }
1750            let q_rot = Tensor::cat(&q_embeds, 0)?;
1751            let k_rot = Tensor::cat(&k_embeds, 0)?;
1752            (q_rot, k_rot)
1753        };
1754
1755        Ok((
1756            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1757            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1758        ))
1759    }
1760}
1761
1762#[derive(Debug, Clone)]
1763pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);
1764
1765#[derive(Debug, Clone, Deserialize, Serialize)]
1766#[serde(rename_all = "lowercase")]
1767pub enum Gemma3nScaledRopeType {
1768    #[serde(alias = "linear")]
1769    Linear,
1770}
1771
1772#[derive(Debug, Clone, Deserialize, Serialize)]
1773pub struct Gemma3nRopeScalingConfig {
1774    factor: f64,
1775    rope_type: Gemma3nScaledRopeType,
1776}
1777
1778impl Gemma3nRotaryEmbedding {
1779    fn new_linear(
1780        cfg: &Gemma3nTextConfig,
1781        factor: f64,
1782        is_gpt_neox: bool,
1783        dtype: DType,
1784        dev: &Device,
1785    ) -> Result<Self> {
1786        let max_seq_len = cfg.max_position_embeddings;
1787        let dim = cfg.head_dim;
1788
1789        let inv_freq: Vec<_> = (0..dim)
1790            .step_by(2)
1791            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1792            .collect();
1793        let inv_freq_len = inv_freq.len();
1794        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
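        // Linear RoPE scaling (position interpolation): dividing every inverse
        // frequency by `factor` stretches each rotation period by `factor`,
        // spreading the trained positional range over a `factor`-times longer context.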
1795        let inv_freq = (inv_freq / factor)?;
1796
1797        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1798            .to_dtype(DType::F32)?
1799            .reshape((max_seq_len, 1))?;
1800        let freqs = t.matmul(&inv_freq)?;
1801        let sin = freqs.sin()?.to_dtype(dtype)?;
1802        let cos = freqs.cos()?.to_dtype(dtype)?;
1803        Ok(Self(RotaryEmbedding {
1804            cos,
1805            sin,
1806            is_gpt_neox,
1807        }))
1808    }
1809
1810    pub fn new(
1811        is_gpt_neox: bool,
1812        dtype: DType,
1813        cfg: &Gemma3nTextConfig,
1814        dev: &Device,
1815    ) -> Result<Self> {
1816        match &cfg.rope_scaling {
1817            Some(Gemma3RopeScalingConfig {
1818                rope_type: Gemma3ScaledRopeType::Linear,
1819                factor,
1820            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1821
1822            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1823        }
1824    }
1825
1826    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
1827        self.0.get_cos_sin()
1828    }
1829
1830    pub fn forward(
1831        &self,
1832        q: &Tensor,
1833        k: &Tensor,
1834        seqlen_offsets: &[usize],
1835    ) -> Result<(Tensor, Tensor)> {
1836        self.0.forward(q, k, seqlen_offsets)
1837    }
1838}
1839
1840#[derive(Debug, Clone)]
1841pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1842
1843#[derive(Debug, Clone, Deserialize, Serialize)]
1844#[serde(rename_all = "lowercase")]
1845pub enum Gemma3ScaledRopeType {
1846    #[serde(alias = "linear")]
1847    Linear,
1848}
1849
1850#[derive(Debug, Clone, Deserialize, Serialize)]
1851pub struct Gemma3RopeScalingConfig {
1852    factor: f64,
1853    rope_type: Gemma3ScaledRopeType,
1854}
1855
1856impl Gemma3RotaryEmbedding {
1857    fn new_linear(
1858        cfg: &Gemma3TextConfig,
1859        factor: f64,
1860        is_gpt_neox: bool,
1861        dtype: DType,
1862        dev: &Device,
1863    ) -> Result<Self> {
1864        let max_seq_len = cfg.max_position_embeddings;
1865        let dim = cfg.head_dim;
1866
1867        let inv_freq: Vec<_> = (0..dim)
1868            .step_by(2)
1869            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1870            .collect();
1871        let inv_freq_len = inv_freq.len();
1872        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1873        let inv_freq = (inv_freq / factor)?;
1874
1875        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1876            .to_dtype(DType::F32)?
1877            .reshape((max_seq_len, 1))?;
1878        let freqs = t.matmul(&inv_freq)?;
1879        let sin = freqs.sin()?.to_dtype(dtype)?;
1880        let cos = freqs.cos()?.to_dtype(dtype)?;
1881        Ok(Self(RotaryEmbedding {
1882            cos,
1883            sin,
1884            is_gpt_neox,
1885        }))
1886    }
1887
1888    pub fn new(
1889        is_gpt_neox: bool,
1890        dtype: DType,
1891        cfg: &Gemma3TextConfig,
1892        dev: &Device,
1893    ) -> Result<Self> {
1894        match &cfg.rope_scaling {
1895            Some(Gemma3RopeScalingConfig {
1896                rope_type: Gemma3ScaledRopeType::Linear,
1897                factor,
1898            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1899
1900            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1901        }
1902    }
1903
1904    pub fn forward(
1905        &self,
1906        q: &Tensor,
1907        k: &Tensor,
1908        seqlen_offsets: &[usize],
1909    ) -> Result<(Tensor, Tensor)> {
1910        self.0.forward(q, k, seqlen_offsets)
1911    }
1912}
1913
1914pub struct DiaRotaryEmbedding {
1915    timescale: Tensor,
1916    dtype: DType,
1917}
1918
1919impl DiaRotaryEmbedding {
1920    pub fn new(
1921        min_timescale: f32,
1922        max_timescale: f32,
1923        head_dim: usize,
1924        device: &Device,
1925        dtype: DType,
1926    ) -> Result<Self> {
1927        assert_eq!(head_dim % 2, 0);
1928        let half_embedding_dim = head_dim / 2;
1929
1930        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
1931        let timescale = fraction
1932            .into_iter()
1933            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
1934            .collect::<Vec<_>>();
1935
1936        let timescale_len = timescale.len();
1937        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
1938
1939        Ok(Self { timescale, dtype })
1940    }
1941
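    /// Rotates `xs` by the angles derived from `positions / timescale`. A hedged
    /// usage sketch (shapes are assumptions for illustration): `positions` of shape
    /// `(batch, seq)` broadcasts against `xs` of shape `(batch, seq, heads, head_dim)`.
    ///
    /// ```ignore
    /// let positions = Tensor::arange(0u32, seq as u32, &device)?
    ///     .to_dtype(DType::F32)?
    ///     .unsqueeze(0)?; // (1, seq), broadcast over the batch
    /// let rotated = rope.forward(&xs, &positions)?;
    /// ```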
1942    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
1943        let freqs = positions
1944            .unsqueeze(D::Minus1)?
1945            .unsqueeze(D::Minus1)?
1946            .broadcast_div(&self.timescale)?;
1947
1948        let sin = freqs.sin()?.to_dtype(self.dtype)?;
1949        let cos = freqs.cos()?.to_dtype(self.dtype)?;
1950
1951        let split = xs.chunk(2, D::Minus1)?;
1952        let first_half = &split[0];
1953        let second_half = &split[1];
1954
1955        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
1956        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
1957
1958        Tensor::cat(&[first_part, second_part], D::Minus1)
1959    }
1960}

1961#[derive(Debug, Clone)]
1962pub struct QLinear {
1963    inner: QMatMul,
1964    bias: Option<Tensor>,
1965    dtype: DType,
1966}
1967
1968impl QLinear {
1969    pub fn new<R: std::io::Read + std::io::Seek>(
1970        ct: &mut Content<'_, R>,
1971        name: &str,
1972        device: &Device,
1973    ) -> Result<Self> {
1974        let w = ct.tensor(&format!("{name}.weight"), device)?;
1975        let b = ct.tensor(&format!("{name}.bias"), device)?;
1976        let inner = QMatMul::from_qtensor(w)?;
1977        let bias = b.dequantize(device)?;
1978        Ok(Self {
1979            inner,
1980            bias: Some(bias),
1981            dtype: DType::F32,
1982        })
1983    }
1984
1985    pub fn from_linear(linear: Linear) -> Self {
1986        Self {
1987            inner: QMatMul::Tensor(linear.weight().clone()),
1988            bias: linear.bias().cloned(),
1989            dtype: linear.weight().dtype(),
1990        }
1991    }
1992
1993    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1994        let dtype = w.dtype();
1995        Self {
1996            inner: QMatMul::Tensor(w),
1997            bias: b,
1998            dtype,
1999        }
2000    }
2001
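    /// Builds a `QLinear` from an already-quantized weight. Any bias must be `f32`,
    /// matching the dequantized matmul output. A minimal sketch (the quantization
    /// format is chosen arbitrarily for illustration):
    ///
    /// ```ignore
    /// use candle_core::quantized::{GgmlDType, QTensor};
    /// let qw = QTensor::quantize(&weight, GgmlDType::Q4K)?;
    /// let layer = QLinear::from_qparts(qw, Some(bias_f32));
    /// ```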
2002    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
2003        if let Some(ref b) = b {
2004            assert_eq!(b.dtype(), DType::F32);
2005        }
2006        Self {
2007            inner: QMatMul::QTensor(Arc::new(w)),
2008            bias: b,
2009            dtype: DType::F32,
2010        }
2011    }
2012
2013    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
2014        Self {
2015            inner,
2016            bias: old.bias.clone(),
2017            dtype: old.dtype,
2018        }
2019    }
2020
2021    pub fn inner(&mut self) -> &mut QMatMul {
2022        &mut self.inner
2023    }
2024
2025    pub fn inner_ref(&self) -> &QMatMul {
2026        &self.inner
2027    }
2028
2029    pub fn is_quant(&self) -> bool {
2030        matches!(self.inner, QMatMul::QTensor(_))
2031    }
2032
2033    pub fn bias(&self) -> Option<&Tensor> {
2034        self.bias.as_ref()
2035    }
2036
2037    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
2038        self.bias.as_mut()
2039    }
2040}
2041
2042impl Module for QLinear {
2043    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2044        let xs = if self.is_quant() {
2045            xs.to_dtype(DType::F32)?
2046        } else {
2047            xs.clone()
2048        };
2049        if let Some(bias) = &self.bias {
2050            self.inner
2051                .forward(&xs)?
2052                .broadcast_add(bias)?
2053                .to_dtype(self.dtype)
2054        } else {
2055            self.inner.forward(&xs)?.to_dtype(self.dtype)
2056        }
2057    }
2058}
2059
2060#[derive(Debug, Clone)]
2061pub struct RotaryEmbedding {
2062    cos: Tensor,
2063    sin: Tensor,
2064    is_gpt_neox: bool,
2065}
2066
2067impl RotaryEmbedding {
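    /// Precomputes sin/cos tables for `max_position_embeddings` positions using
    /// inverse frequencies `base^(-i/head_dim)` over even `i`. A hedged usage sketch
    /// (hyperparameters are illustrative, not from any particular model):
    ///
    /// ```ignore
    /// let rope = RotaryEmbedding::new(10_000.0, 64, 4096, &Device::Cpu, true, DType::F32)?;
    /// let (q, k) = rope.forward(&q, &k, &[0])?;
    /// ```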
2068    pub fn new(
2069        base: f32,
2070        head_dim: usize,
2071        max_position_embeddings: usize,
2072        device: &Device,
2073        is_gpt_neox: bool,
2074        dtype: DType,
2075    ) -> Result<Self> {
2076        let inv_freq: Vec<_> = (0..head_dim)
2077            .step_by(2)
2078            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
2079            .collect();
2080        let inv_freq_len = inv_freq.len();
2081        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2082        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2083            .to_dtype(DType::F32)?
2084            .reshape((max_position_embeddings, 1))?;
2085        let freqs = t.matmul(&inv_freq)?;
2086        let sin = freqs.sin()?.to_dtype(dtype)?;
2087        let cos = freqs.cos()?.to_dtype(dtype)?;
2088
2089        Ok(Self {
2090            cos,
2091            sin,
2092            is_gpt_neox,
2093        })
2094    }
2095
2096    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
2097        Ok((self.cos.clone(), self.sin.clone()))
2098    }
2099
2100    pub fn new_partial(
2101        base: f32,
2102        rot_dim: usize,
2103        max_position_embeddings: usize,
2104        device: &Device,
2105        is_gpt_neox: bool,
2106        dtype: DType,
2107    ) -> Result<Self> {
2108        let inv_freq: Vec<_> = (0..rot_dim)
2109            .step_by(2)
2110            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
2111            .collect();
2112        let inv_freq_len = inv_freq.len();
2113        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2114        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2115            .to_dtype(DType::F32)?
2116            .reshape((max_position_embeddings, 1))?;
2117        let freqs = t.matmul(&inv_freq)?;
2118        let sin = freqs.sin()?.to_dtype(dtype)?;
2119        let cos = freqs.cos()?.to_dtype(dtype)?;
2120
2121        Ok(Self {
2122            cos,
2123            sin,
2124            is_gpt_neox,
2125        })
2126    }
2127
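    /// Applies RoPE to `q`/`k` of shape `(batch, heads, seq_len, head_dim)`.
    /// With `is_gpt_neox == true` the rotation pairs dimension `d` with `d + head_dim/2`
    /// (candle's `rope`); otherwise adjacent dimensions `(2d, 2d + 1)` are paired
    /// (candle's `rope_i`). `seqlen_offsets` holds one KV-cache offset per sequence.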
2128    pub fn forward(
2129        &self,
2130        q: &Tensor,
2131        k: &Tensor,
2132        seqlen_offsets: &[usize],
2133    ) -> Result<(Tensor, Tensor)> {
2134        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
2135        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
2136
2137        let rope = if self.is_gpt_neox {
2138            candle_nn::rotary_emb::rope
2139        } else {
2140            candle_nn::rotary_emb::rope_i
2141        };
2142
2143        if cfg!(feature = "cuda") && qh == kh {
2144            let (cos, sin) = if seqlen_offsets.len() == 1 {
2145                (
2146                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
2147                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
2148                )
2149            } else {
2150                let mut cos_s = Vec::new();
2151                let mut sin_s = Vec::new();
2152                for offset in seqlen_offsets {
2153                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
2154                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
2155                }
2156                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
2157            };
2158
2159            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
2160            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
2161            mistralrs_quant::rotary::apply_rotary_inplace(
2162                &q_embed,
2163                &k_embed,
2164                &cos,
2165                &sin,
2166                self.is_gpt_neox,
2167            )?;
2168            let mut q = q_embed
2169                .reshape((b_sz, seq_len, qh, n_embd))?
2170                .transpose(1, 2)?;
2171            let mut k = k_embed
2172                .reshape((b_sz, seq_len, kh, n_embd))?
2173                .transpose(1, 2)?;
2174            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
2175                q = q.contiguous()?;
2176                k = k.contiguous()?;
2177            }
2178            Ok((q, k))
2179        } else if seqlen_offsets.len() == 1 {
2180            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2181            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2182            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
2183            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
2184            Ok((q_embed, k_embed))
2185        } else {
2186            let mut q_embeds = Vec::new();
2187            let mut k_embeds = Vec::new();
2188            for (i, offset) in seqlen_offsets.iter().enumerate() {
2189                let cos = self.cos.narrow(0, *offset, seq_len)?;
2190                let sin = self.sin.narrow(0, *offset, seq_len)?;
2191                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2192                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2193                q_embeds.push(q_embed);
2194                k_embeds.push(k_embed);
2195            }
2196            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
2197        }
2198    }
2199}
2200
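/// Activation functions as they appear in HuggingFace-style `config.json` files
/// (serde deserializes the lowercase / aliased names). A hedged sketch, assuming
/// `serde_json` is available in the calling crate:
///
/// ```ignore
/// let act: Activation = serde_json::from_str("\"silu\"")?;
/// assert_eq!(act, Activation::Silu);
/// ```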
2201#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
2202#[serde(rename_all = "lowercase")]
2203pub enum Activation {
2204    #[default]
2205    #[serde(alias = "gelu")]
2206    Gelu,
2207    #[serde(alias = "gelu_new")]
2208    NewGelu,
2209    Relu,
2210    Relu2,
2211    Relu6,
2212    Silu,
2213    Sigmoid,
2214    HardSigmoid,
2215    Swiglu,
2216    Swish,
2217    HardSwish,
2218    Elu(f64),
2219    LeakyRelu(f64),
2220    #[serde(alias = "gelu_pytorch_tanh")]
2221    GeluPytorchTanh,
2222    QuickGelu,
2223}
2224
2225impl Module for Activation {
2226    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2227        match self {
2228            Self::Gelu => xs.gelu_erf(),
2229            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
2230            Self::NewGelu => xs.gelu(),
2231            Self::Relu => xs.relu(),
2232            Self::Relu2 => xs.relu()?.sqr(),
2233            Self::Relu6 => xs.clamp(0f32, 6f32),
2234            Self::Silu => xs.silu(),
2235            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
2236            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
2237            Self::Swiglu => candle_nn::ops::swiglu(xs),
2238            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
2239            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
2240            &Self::Elu(alpha) => xs.elu(alpha),
2241            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
2242            Self::GeluPytorchTanh => xs.gelu(),
2243            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?)?,
2244        }
2245    }
2246}
2247
2248impl TryInto<candle_nn::Activation> for Activation {
2249    type Error = candle_core::Error;
2250
2251    fn try_into(self) -> Result<candle_nn::Activation> {
2252        match self {
2253            Self::Gelu => Ok(candle_nn::Activation::Gelu),
2254            Self::Relu => Ok(candle_nn::Activation::Relu),
2255            Self::Silu => Ok(candle_nn::Activation::Silu),
2256            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
2257            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
2258            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
2259            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
2260            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
2261            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
2262            Self::Swish => Ok(candle_nn::Activation::Swish),
2263            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
2264            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
2265            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
2266            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
2267            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
2268        }
2269    }
2270}
2271
2272#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2273pub struct Conv3dConfig {
2274    pub padding: usize,
2275    pub stride: usize,
2276    pub dilation: usize,
2277    pub groups: usize,
2278}
2279
2280impl Default for Conv3dConfig {
2281    fn default() -> Self {
2282        Self {
2283            padding: 0,
2284            stride: 1,
2285            dilation: 1,
2286            groups: 1,
2287        }
2288    }
2289}
2290
2291pub struct Conv3dNoBias {
2292    conv2d_1: Conv2d,
2293    conv2d_2: Conv2d,
2294}
2295
2296impl Conv3dNoBias {
2297    pub fn new(
2298        in_channels: usize,
2299        out_channels: usize,
2300        kernel_sizes: [usize; 3],
2301        cfg: Conv3dConfig,
2302        vb: ShardedVarBuilder,
2303    ) -> Result<Self> {
2304        let ws = vb.get(
2305            (
2306                out_channels,
2307                in_channels / cfg.groups,
2308                kernel_sizes[0],
2309                kernel_sizes[1],
2310                kernel_sizes[2],
2311            ),
2312            "weight",
2313        )?;
2314
2315        // Split on temporal dimension
2316        // https://github.com/pytorch/pytorch/issues/139066
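        // With a temporal kernel of size 2 and an input depth of 2, the single
        // temporal output frame is conv2d(w[:, :, 0], x[:, :, 0]) +
        // conv2d(w[:, :, 1], x[:, :, 1]), which is exactly what `forward` computes
        // with the two 2D convolutions below.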
2317
2318        let w1 = ws.i((.., .., 0, .., ..))?;
2319        let w2 = ws.i((.., .., 1, .., ..))?;
2320
2321        let cfg = Conv2dConfig {
2322            padding: cfg.padding,
2323            stride: cfg.stride,
2324            dilation: cfg.dilation,
2325            groups: cfg.groups,
2326            cudnn_fwd_algo: None,
2327        };
2328
2329        Ok(Self {
2330            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
2331            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
2332        })
2333    }
2334
2335    pub fn weight(&self) -> Result<Tensor> {
2336        let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?;
2337        let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?;
2338        Tensor::cat(&[w1, w2], 2)
2339    }
2340}
2341
2342impl Module for Conv3dNoBias {
2343    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2344        let xs1 = xs.i((.., .., 0, .., ..))?;
2345        let xs2 = xs.i((.., .., 1, .., ..))?;
2346
2347        (Convolution.forward_2d(&self.conv2d_1, &xs1)?
2348            + Convolution.forward_2d(&self.conv2d_2, &xs2)?)?
2349        .unsqueeze(2)
2350    }
2351}
2352
2353pub trait TensorInfExtend {
2354    fn is_inf(&self) -> Result<Self>
2355    where
2356        Self: Sized;
2357    fn any(&self) -> Result<bool>;
2358}
2359
2360impl TensorInfExtend for Tensor {
2361    fn is_inf(&self) -> Result<Self> {
2362        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
2363    }
2364
2365    fn any(&self) -> Result<bool> {
2366        let sum = self.sum_all()?;
2367        match self.dtype() {
2368            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
2369            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
2370            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
2371            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
2372            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
2373            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
2374            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
2375            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
2376            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
2377            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
2378            DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
2379                candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with .any")
2380            }
2381        }
2382    }
2383}
2384
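/// Clamps activations to just below the maximum finite value of their dtype
/// (backing off further when infinities are already present), so that a later
/// cast to f16 does not overflow.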
2385pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
2386    let mut max = match xs.dtype() {
2387        DType::U8 => u8::MAX as f32 - 1000.,
2388        DType::U32 => u32::MAX as f32 - 1000.,
2389        DType::I16 => i16::MAX as f32 - 1000.,
2390        DType::I32 => i32::MAX as f32 - 1000.,
2391        DType::I64 => i64::MAX as f32 - 1000.,
2392        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
2393        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
2394        DType::F32 => f32::MAX - 1000.,
2395        DType::F64 => f64::MAX as f32 - 1000.,
2396        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
2397        DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
2398            candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with clamp_for_f16")
2399        }
2400    };
2401    if xs.is_inf()?.any()? {
2402        max -= 1000.;
2403    }
2404    xs.clamp(-max, max)
2405}
2406
2407pub struct FloatInfo {
2408    /// Minimum representable value.
2409    pub min: f64,
2410    /// Maximum representable value.
2411    pub max: f64,
2412    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
2413    pub eps: f64,
2414    pub dtype: DType,
2415}
2416
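/// Extension trait exposing `torch.finfo`-style limits for float `DType`s.
/// A small sketch of what it returns:
///
/// ```ignore
/// let info = DType::F16.finfo()?;
/// assert_eq!(info.max, 65504.0);
/// ```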
2417pub trait GetFloatInfo {
2418    fn finfo(&self) -> Result<FloatInfo>;
2419}
2420
2421impl GetFloatInfo for DType {
2422    fn finfo(&self) -> Result<FloatInfo> {
2423        let finfo = match self {
2424            Self::BF16 => FloatInfo {
2425                min: bf16::MIN.to_f64(),
2426                max: bf16::MAX.to_f64(),
2427                eps: bf16::EPSILON.to_f64(),
2428                dtype: DType::BF16,
2429            },
2430            Self::F16 => FloatInfo {
2431                min: f16::MIN.to_f64(),
2432                max: f16::MAX.to_f64(),
2433                eps: f16::EPSILON.to_f64(),
2434                dtype: DType::F16,
2435            },
2436            Self::F32 => FloatInfo {
2437                min: f32::MIN as f64,
2438                max: f32::MAX as f64,
2439                eps: f32::EPSILON as f64,
2440                dtype: DType::F32,
2441            },
2442            Self::F64 => FloatInfo {
2443                min: f64::MIN,
2444                max: f64::MAX,
2445                eps: f64::EPSILON,
2446                dtype: DType::F64,
2447            },
2448            Self::F8E4M3 => FloatInfo {
2449                min: F8E4M3::MIN.to_f64(),
2450                max: F8E4M3::MAX.to_f64(),
2451                eps: F8E4M3::EPSILON.to_f64(),
2452                dtype: DType::F8E4M3,
2453            },
2454            other => {
2455                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
2456            }
2457        };
2458        Ok(finfo)
2459    }
2460}
2461
2462#[derive(Clone)]
2463pub struct Mlp {
2464    pub gate: Arc<dyn QuantMethod>,
2465    pub up: Arc<dyn QuantMethod>,
2466    pub down: Arc<dyn QuantMethod>,
2467    act: Activation,
2468    params: Vec<usize>,
2469}
2470
2471impl Mlp {
2472    pub fn new(
2473        vb: ShardedVarBuilder,
2474        hidden_size: usize,
2475        intermediate_size: usize,
2476        quantization_config: &Option<QuantizedConfig>,
2477        hidden_act: Activation,
2478        comm: &Arc<mistralrs_quant::Comm>,
2479    ) -> Result<Self> {
2480        Ok(Self {
2481            gate: ColumnParallelLayer::new(
2482                hidden_size,
2483                intermediate_size,
2484                quantization_config,
2485                false,
2486                comm,
2487                vb.pp("gate_proj"),
2488            )?,
2489            up: ColumnParallelLayer::new(
2490                hidden_size,
2491                intermediate_size,
2492                quantization_config,
2493                false,
2494                comm,
2495                vb.pp("up_proj"),
2496            )?,
2497            down: RowParallelLayer::new(
2498                intermediate_size,
2499                hidden_size,
2500                quantization_config,
2501                false,
2502                comm,
2503                vb.pp("down_proj"),
2504            )?,
2505            act: hidden_act,
2506            params: vec![hidden_size, intermediate_size],
2507        })
2508    }
2509
2510    pub fn new_merged(
2511        vb: ShardedVarBuilder,
2512        hidden_size: usize,
2513        intermediate_size: usize,
2514        chunks: usize,
2515        quantization_config: &Option<QuantizedConfig>,
2516        hidden_act: Activation,
2517        comm: &Arc<mistralrs_quant::Comm>,
2518    ) -> Result<Self> {
2519        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
2520        let gate_up_projs = ColumnParallelLayer::new_merged(
2521            hidden_size,
2522            intermediate_size * 2,
2523            2,
2524            quantization_config,
2525            false,
2526            comm,
2527            vb.pp("gate_up_proj"),
2528        )?;
2529
2530        Ok(Self {
2531            gate: gate_up_projs[0].to_owned(),
2532            up: gate_up_projs[1].to_owned(),
2533            down: RowParallelLayer::new(
2534                intermediate_size,
2535                hidden_size,
2536                quantization_config,
2537                false,
2538                comm,
2539                vb.pp("down_proj"),
2540            )?,
2541            act: hidden_act,
2542            params: vec![hidden_size, intermediate_size],
2543        })
2544    }
2545
2546    pub fn replicate(
2547        params: &[usize],
2548        vb: ShardedVarBuilder,
2549        act: Activation,
2550        comm: &Arc<mistralrs_quant::Comm>,
2551    ) -> Result<Self> {
2552        Self::new(vb, params[0], params[1], &None, act, comm)
2553    }
2554
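    /// Gated MLP: computes `down(act(gate(x)) * up(x))`, casting into the quantized
    /// activation dtype when the gate projection requires one and back to the input
    /// dtype afterwards.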
2555    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2556        let original_dtype = xs.dtype();
2557        let mut xs = xs.clone();
2558        if let Some(t) = self.gate.quantized_act_type() {
2559            xs = xs.to_dtype(t)?;
2560        }
2561        let lhs = self.gate.forward(&xs)?;
2562        let rhs = self.up.forward(&xs)?;
2563        let mut res = self
2564            .down
2565            .forward(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?)?;
2566        if self.gate.quantized_act_type().is_some() {
2567            res = res.to_dtype(original_dtype)?;
2568        }
2569        Ok(res)
2570    }
2571}
2572
2573impl AnyMoeTrainableLayer for Mlp {}
2574
2575impl MlpLayer for Mlp {
2576    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2577        let original_dtype = xs.dtype();
2578        let mut xs = xs.clone();
2579        if let Some(t) = self.gate.quantized_act_type() {
2580            xs = xs.to_dtype(t)?;
2581        }
2582        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2583        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
2584        let mut res =
2585            MatMul.qmethod_matmul(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?, &*self.down)?;
2586        if self.gate.quantized_act_type().is_some() {
2587            res = res.to_dtype(original_dtype)?;
2588        }
2589        Ok(res)
2590    }
2591    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2592        vec![&mut self.gate, &mut self.up, &mut self.down]
2593    }
2594    fn clone(&self) -> Box<dyn MlpLayer> {
2595        Box::new(Clone::clone(self))
2596    }
2597    fn get_params(&self) -> &[usize] {
2598        &self.params
2599    }
2600    fn hidden_act(&self) -> Activation {
2601        self.act
2602    }
2603    // gate, up, down
2604    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2605        let gate = if let Some(ref delta) = deltas[0] {
2606            self.gate.add_delta_w(delta)?
2607        } else {
2608            self.gate.clone()
2609        };
2610        let up = if let Some(ref delta) = deltas[1] {
2611            self.up.add_delta_w(delta)?
2612        } else {
2613            self.up.clone()
2614        };
2615        let down = if let Some(ref delta) = deltas[2] {
2616            self.down.add_delta_w(delta)?
2617        } else {
2618            self.down.clone()
2619        };
2620
2621        Ok(Box::new(Self {
2622            gate,
2623            up,
2624            down,
2625            act: self.act,
2626            params: self.params.clone(),
2627        }))
2628    }
2629
2630    fn dtype_device(&self) -> (DType, Device) {
2631        self.gate.dtype_and_device()
2632    }
2633}
2634
2635pub struct AvgPool2d {
2636    kernel_size: usize,
2637    stride: usize,
2638}
2639
2640impl AvgPool2d {
2641    pub fn new(kernel_size: usize, stride: usize) -> Self {
2642        Self {
2643            kernel_size,
2644            stride,
2645        }
2646    }
2647
2648    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2649        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2650    }
2651}
2652
2653/// Applies 2D reflection padding to a tensor of shape (N, C, H, W).
2654///
2655/// The `padding` argument is a 4-tuple (pad_left, pad_right, pad_top, pad_bottom).
2656/// For left padding, it reflects the values from column 1 up to pad_left (in reverse order);
2657/// for right padding, it reflects from the second-to-last column backwards, and similarly for
2658/// vertical (height) padding.
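///
/// For example, padding a single row `[a, b, c, d]` with `(pad_left, pad_right,
/// pad_top, pad_bottom) = (2, 1, 0, 0)` yields `[c, b, a, b, c, d, c]`.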
2659pub struct ReflectionPad2d {
2660    padding: (usize, usize, usize, usize),
2661}
2662
2663impl ReflectionPad2d {
2664    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2665        Self { padding }
2666    }
2667}
2668
2669impl Module for ReflectionPad2d {
2670    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2671        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2672
2673        let (_n, _c, h, w) = xs.dims4()?;
2674
2675        // --- Horizontal Padding (along width, axis = 3) ---
2676        // For left padding, we reflect columns 1..=pad_left (in reverse order).
2677        let left_pad = if pad_left > 0 {
2678            // Create indices: [pad_left, pad_left-1, ..., 1]
2679            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2680            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2681        } else {
2682            None
2683        };
2684
2685        // For right padding, we reflect from the right side (excluding the last column).
2686        let right_pad = if pad_right > 0 {
2687            // For pad_right columns, generate indices: [w-2, w-3, ..., w-1-pad_right]
2688            let start = w as i64 - 2;
2689            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2690            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2691        } else {
2692            None
2693        };
2694
2695        // Concatenate horizontally (along width, dim=3)
2696        let x_padded_width = match (left_pad, right_pad) {
2697            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2698            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2699            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2700            (None, None) => xs.clone(),
2701        };
2702
2703        // --- Vertical Padding (along height, axis = 2) ---
2704        // For top padding, reflect rows 1..=pad_top (in reverse order)
2705        let top_pad = if pad_top > 0 {
2706            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2707            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2708        } else {
2709            None
2710        };
2711
2712        // For bottom padding, reflect from the bottom (excluding the last row)
2713        let bottom_pad = if pad_bottom > 0 {
2714            let start = h as i64 - 2;
2715            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2716            Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2717        } else {
2718            None
2719        };
2720
2721        // Concatenate vertically (along height, dim=2)
2722        let x_padded = match (top_pad, bottom_pad) {
2723            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2724            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2725            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2726            (None, None) => x_padded_width,
2727        };
2728
2729        Ok(x_padded)
2730    }
2731}
2732
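/// An embedding whose output is multiplied by a constant scale. A hedged usage
/// sketch (the sqrt(hidden_size) factor is the common Gemma-style convention,
/// shown here only as an example):
///
/// ```ignore
/// let scaled = ScaledEmbedding::new((hidden_size as f64).sqrt(), embedding);
/// let xs = scaled.forward(&token_ids)?;
/// ```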
2733pub struct ScaledEmbedding {
2734    scale: f64,
2735    pub embedding: Tensor,
2736}
2737
2738impl ScaledEmbedding {
2739    pub fn new(scale: f64, embedding: Embedding) -> Self {
2740        Self {
2741            scale,
2742            embedding: embedding.embeddings().clone(),
2743        }
2744    }
2745
2746    pub fn embeddings(&self) -> &Tensor {
2747        &self.embedding
2748    }
2749}
2750
2751impl Module for ScaledEmbedding {
2752    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2753        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
2754        xs.apply(&embedding)? * self.scale
2755    }
2756}