mistralrs_core/layers.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6    quantized::{QMatMul, QTensor},
7    Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10    BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
11    LayerNorm, LayerNormConfig, Linear, Module,
12};
13use float8::F8E4M3;
14use half::{bf16, f16};
15use mistralrs_quant::{
16    AfqLayer, ColumnParallelLayer, Convolution, QuantMethod, QuantizedConfig, RowParallelLayer,
17    ShardedVarBuilder,
18};
19use serde::{Deserialize, Serialize};
20
21pub use crate::attention::Sdpa;
22pub use crate::layers_masker::CausalMasker;
23pub use crate::layers_utils::repeat_kv;
24use crate::{
25    amoe::{AnyMoeTrainableLayer, MlpLayer},
26    embedding_models::embedding_gemma::EmbeddingGemmaConfig,
27    gguf::Content,
28    models::{llama, smollm3},
29    ops::SplitOp,
30    vision_models::{
31        gemma3::config::Gemma3TextConfig,
32        gemma3n::config::Gemma3nTextConfig,
33        llama4,
34        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
35        phi4::Phi4MMConfig,
36    },
37};
38
39pub use mistralrs_quant::MatMul;
40
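/// Loads token embeddings from `vb`. For AFQ-quantized checkpoints the quantized
/// weight is dequantized once at load time, so lookups afterwards behave like a
/// plain [`Embedding`].
///
/// Hedged usage sketch; the prefix and config field names below are illustrative,
/// not taken from any particular model in this crate:
///
/// ```ignore
/// let tok_emb = embedding(
///     cfg.vocab_size,
///     cfg.hidden_size,
///     vb.pp("model.embed_tokens"),
///     &cfg.quantization_config,
/// )?;
/// ```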
41pub fn embedding(
42    in_size: usize,
43    out_size: usize,
44    vb: ShardedVarBuilder,
45    config: &Option<QuantizedConfig>,
46) -> Result<Embedding> {
47    // AFQ-quantized checkpoints also quantize the embedding weights; dequantize them here before building the dense embedding table.
48    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
49        let afq_layer =
50            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
51        afq_layer.dequantize_w()?
52    } else {
53        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
54    };
55    Ok(Embedding::new(embeddings, out_size))
56}
57
58pub fn layer_norm<C: Into<LayerNormConfig>>(
59    size: usize,
60    config: C,
61    vb: ShardedVarBuilder,
62) -> Result<LayerNorm> {
63    let config = config.into();
64    let weight = vb.get(size, "weight")?;
65    if config.affine {
66        let bias = vb.get(size, "bias")?;
67        Ok(LayerNorm::new(weight, bias, config.eps))
68    } else {
69        Ok(LayerNorm::new_no_bias(weight, config.eps))
70    }
71}
72
73pub fn batch_norm<C: Into<BatchNormConfig>>(
74    num_features: usize,
75    config: C,
76    vb: ShardedVarBuilder,
77) -> Result<BatchNorm> {
78    let config = config.into();
79    if config.eps < 0. {
80        candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
81    }
82    let running_mean = vb.get(num_features, "running_mean")?;
83    let running_var = vb.get(num_features, "running_var")?;
84
85    if config.affine {
86        let weight = vb.get(num_features, "weight")?;
87        let bias = vb.get(num_features, "bias")?;
88        BatchNorm::new(
89            num_features,
90            running_mean,
91            running_var,
92            weight,
93            bias,
94            config.eps,
95        )
96    } else {
97        BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
98    }
99}
100
101pub fn group_norm(
102    num_groups: usize,
103    num_channels: usize,
104    eps: f64,
105    vb: ShardedVarBuilder,
106) -> Result<GroupNorm> {
107    let weight = vb.get(num_channels, "weight")?;
108    let bias = vb.get(num_channels, "bias")?;
109    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
110}
111
112pub fn conv2d(
113    in_channels: usize,
114    out_channels: usize,
115    kernel_size: usize,
116    cfg: Conv2dConfig,
117    vb: ShardedVarBuilder,
118) -> Result<Conv2d> {
119    let ws = vb.get(
120        (
121            out_channels,
122            in_channels / cfg.groups,
123            kernel_size,
124            kernel_size,
125        ),
126        "weight",
127    )?;
128    let bs = vb.get(out_channels, "bias")?;
129    Ok(Conv2d::new(ws, Some(bs), cfg))
130}
131
132pub fn conv2d_no_bias(
133    in_channels: usize,
134    out_channels: usize,
135    kernel_size: usize,
136    cfg: Conv2dConfig,
137    vb: ShardedVarBuilder,
138) -> Result<Conv2d> {
139    let ws = vb.get(
140        (
141            out_channels,
142            in_channels / cfg.groups,
143            kernel_size,
144            kernel_size,
145        ),
146        "weight",
147    )?;
148    Ok(Conv2d::new(ws, None, cfg))
149}
150
151pub fn conv1d(
152    in_channels: usize,
153    out_channels: usize,
154    kernel_size: usize,
155    cfg: Conv1dConfig,
156    vb: ShardedVarBuilder,
157) -> Result<Conv1d> {
158    let ws = vb.get(
159        (out_channels, in_channels / cfg.groups, kernel_size),
160        "weight",
161    )?;
162    let bs = vb.get(out_channels, "bias")?;
163    Ok(Conv1d::new(ws, Some(bs), cfg))
164}
165
166pub fn conv1d_no_bias(
167    in_channels: usize,
168    out_channels: usize,
169    kernel_size: usize,
170    cfg: Conv1dConfig,
171    vb: ShardedVarBuilder,
172) -> Result<Conv1d> {
173    let ws = vb.get(
174        (out_channels, in_channels / cfg.groups, kernel_size),
175        "weight",
176    )?;
177    Ok(Conv1d::new(ws, None, cfg))
178}
179
180pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
181    let ws = vb.get((out_dim, in_dim), "weight")?;
182    let bs = vb.get(out_dim, "bias")?;
183    Ok(Linear::new(ws, Some(bs)))
184}
185
186pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
187    let ws = vb.get((out_dim, in_dim), "weight")?;
188    Ok(Linear::new(ws, None))
189}
190
191pub fn linear_b(
192    in_dim: usize,
193    out_dim: usize,
194    bias: bool,
195    vb: ShardedVarBuilder,
196) -> Result<Linear> {
197    if bias {
198        linear(in_dim, out_dim, vb)
199    } else {
200        linear_no_bias(in_dim, out_dim, vb)
201    }
202}
203
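/// Root-mean-square layer normalization.
///
/// For an input `x` and learned scale `w`, the forward pass computes
/// `x / sqrt(mean(x^2) + eps) * w` (no mean subtraction, no bias), delegating to
/// `candle_nn::ops::rms_norm` on a contiguous input.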
204#[derive(Debug, Clone)]
205pub struct RmsNorm {
206    eps: f64,
207    weight: Tensor,
208}
209
210impl RmsNorm {
211    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
212        let w = vb.get(size, "weight")?;
213        Ok(Self { eps, weight: w })
214    }
215
216    /// Gemma uses weight + 1.0
217    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
218        let w = vb.get(size, "weight")?;
219        let w = (w + 1.0)?;
220        Ok(Self { eps, weight: w })
221    }
222
223    /// Gemma 3n RMSNorm; when `with_scale` is false, the scale weight is fixed to all ones.
224    pub fn new_gemma_3n(
225        size: usize,
226        eps: f64,
227        with_scale: bool,
228        vb: ShardedVarBuilder,
229    ) -> Result<Self> {
230        let w = if with_scale {
231            vb.get(size, "weight")?
232        } else {
233            Tensor::ones(size, vb.dtype(), vb.device())?
234        };
235        Ok(Self { eps, weight: w })
236    }
237
238    /// Gemma uses weight + 1.0. Undo for UQFF generation.
239    pub fn undo_gemma(&self) -> Result<Self> {
240        Ok(Self {
241            eps: self.eps,
242            weight: (&self.weight - 1.0)?,
243        })
244    }
245
246    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
247        Ok(Self { eps, weight: w })
248    }
249
250    pub fn weight(&self) -> &Tensor {
251        &self.weight
252    }
253}
254
255impl Module for RmsNorm {
256    fn forward(&self, x: &Tensor) -> Result<Tensor> {
257        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
258    }
259}
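
#[cfg(test)]
mod rms_norm_tests {
    use super::*;

    // Hedged sketch: checks `RmsNorm` against the explicit
    // x / sqrt(mean(x^2) + eps) * w formula on a tiny f32 tensor. The constants
    // here are illustrative only.
    #[test]
    fn rms_norm_matches_reference() -> Result<()> {
        let dev = Device::Cpu;
        let x = Tensor::new(&[[1f32, 2., 3., 4.]], &dev)?;
        let w = Tensor::new(&[0.5f32, 1., 1.5, 2.], &dev)?;
        let norm = RmsNorm::from_w(w.clone(), 1e-6)?;
        let got = norm.forward(&x)?;

        // Reference implementation of the same formula.
        let var = x.powf(2.)?.mean_keepdim(D::Minus1)?;
        let rms = (var + 1e-6)?.sqrt()?;
        let expected = x.broadcast_div(&rms)?.broadcast_mul(&w)?;

        let diff = (got - expected)?
            .abs()?
            .flatten_all()?
            .max(0)?
            .to_scalar::<f32>()?;
        assert!(diff < 1e-5);
        Ok(())
    }
}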
260
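/// RMSNorm variant that always accumulates in `f32`.
///
/// The input is upcast to `f32` for the `mean(x^2)` reduction and normalization,
/// cast back to its original dtype, and only then multiplied by the weight, which
/// keeps the reduction stable for `f16`/`bf16` activations.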
261#[derive(Debug, Clone)]
262pub struct F32RmsNorm {
263    w: Tensor,
264    eps: f64,
265}
266
267impl F32RmsNorm {
268    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
269        Ok(Self {
270            w: vb.get((size,), "weight")?,
271            eps,
272        })
273    }
274
275    pub fn weight(&self) -> &Tensor {
276        &self.w
277    }
278}
279
280impl Module for F32RmsNorm {
281    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
282        let initial_type = xs.dtype();
283        let mut xs = xs.to_dtype(DType::F32)?;
284        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
285        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
286        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
287    }
288}
289
290#[derive(Debug, Clone)]
291pub struct QRmsNorm {
292    eps: f64,
293    weight: Tensor,
294}
295
296impl QRmsNorm {
297    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
298        let scale = scale.dequantize(&scale.device())?;
299        Ok(Self {
300            eps: eps as f64,
301            weight: scale,
302        })
303    }
304
305    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
306        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
307    }
308}
309
310/// Rotary position embedding with optional LongRope (short/long factor) scaling, as used by the Phi models.
311#[derive(Debug, Clone)]
312pub struct PhiRotaryEmbedding {
313    short_sin: Tensor,
314    short_cos: Tensor,
315    long_cos: Option<Tensor>,
316    long_sin: Option<Tensor>,
317    original_max_position_embeddings: usize,
318}
319
320#[derive(Debug, Clone, Deserialize, Serialize)]
321#[serde(rename_all = "lowercase")]
322pub enum ScaledRopeType {
323    #[serde(alias = "su")]
324    #[serde(alias = "longrope")]
325    Su,
326    #[serde(alias = "yarn")]
327    Yarn,
328    #[serde(alias = "dynamic")]
329    Dynamic,
330    #[serde(alias = "linear")]
331    Linear,
332}
333
334impl FromStr for ScaledRopeType {
335    type Err = candle_core::Error;
336    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
337        match s {
338            "su" | "longrope" => Ok(Self::Su),
339            "yarn" => Ok(Self::Yarn),
340            "linear" => Ok(Self::Linear),
341            "dynamic" => Ok(Self::Dynamic),
342            _ => Err(candle_core::Error::Msg(
343                "Expected one of `su`/`longrope`, `yarn`, `linear`, or `dynamic` scaled RoPE type.".to_string(),
344            )),
345        }
346    }
347}
348
349#[derive(Debug, Clone, Deserialize, Serialize)]
350#[serde(untagged)]
351pub enum PhiRopeScalingConfig {
352    Classic {
353        short_factor: Vec<f64>,
354        long_factor: Vec<f64>,
355        #[serde(rename = "type")]
356        scaling_type: ScaledRopeType,
357    },
358    Scaled {
359        short_factor: Vec<f64>,
360        long_factor: Vec<f64>,
361        #[serde(rename = "type")]
362        scaling_type: ScaledRopeType,
363        long_mscale: f64,
364        short_mscale: f64,
365    },
366}
367
368pub struct PhiRopeConfig {
369    pub rope_scaling: Option<PhiRopeScalingConfig>,
370    pub max_position_embeddings: usize,
371    pub original_max_position_embeddings: usize,
372    pub rope_theta: f64,
373    pub head_dim: usize,
374    pub partial_rotary_factor: Option<f64>,
375}
376
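// LongRope attention scaling as implemented in `new_classic_scaled` below: with
// scale = max_position_embeddings / original_max_position_embeddings, the sin/cos
// tables are multiplied by sqrt(1 + ln(scale) / ln(original_max_position_embeddings))
// for the `su`/`longrope` type (or by 0.1 * ln(scale) + 1 for `yarn`) whenever
// scale > 1. Two tables are built, one from `short_factor` and one from
// `long_factor`; `forward` switches to the long table only once the sequence grows
// past `original_max_position_embeddings`.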
377impl PhiRotaryEmbedding {
378    fn new_classic_scaled(
379        short_factor: &[f64],
380        long_factor: &[f64],
381        scaling_type: &ScaledRopeType,
382        cfg: &PhiRopeConfig,
383        dtype: DType,
384        dev: &Device,
385    ) -> Result<Self> {
386        let max_seq_len = cfg.max_position_embeddings;
387        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
388
389        // Calculate scale
390        let scale =
391            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
392        let scaling_factor = if scale <= 1.0 {
393            1.0
394        } else {
395            match scaling_type {
396                ScaledRopeType::Su => {
397                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
398                }
399                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
400                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
401            }
402        };
403
404        // Calculate inv freqs for short, long
405        let inv_freq_long = (0..dim)
406            .step_by(2)
407            .enumerate()
408            .map(|(k, i)| {
409                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
410            })
411            .collect::<Vec<_>>();
412        let inv_freq_short = (0..dim)
413            .step_by(2)
414            .enumerate()
415            .map(|(k, i)| {
416                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
417            })
418            .collect::<Vec<_>>();
419        let inv_freq_len = inv_freq_long.len();
420
421        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
422            .to_dtype(DType::F32)?
423            .reshape((max_seq_len, 1))?;
424
425        // Calculate sin,cos for long
426        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
427        let freqs_long = t.matmul(&inv_freq_long)?;
428        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
429        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
430
431        // Calculate sin,cos for short
432        let inv_freq_short =
433            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
434        let freqs_short = t.matmul(&inv_freq_short)?;
435        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
436        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
437
438        Ok(Self {
439            short_cos,
440            short_sin,
441            long_cos: Some(long_cos),
442            long_sin: Some(long_sin),
443            original_max_position_embeddings: cfg.original_max_position_embeddings,
444        })
445    }
446
447    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
448        let max_seq_len = cfg.max_position_embeddings;
449        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
450
451        let inv_freq: Vec<_> = (0..dim)
452            .step_by(2)
453            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
454            .collect();
455        let inv_freq_len = inv_freq.len();
456        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
457        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
458            .to_dtype(DType::F32)?
459            .reshape((max_seq_len, 1))?;
460        let freqs = t.matmul(&inv_freq)?;
461        let sin = freqs.sin()?.to_dtype(dtype)?;
462        let cos = freqs.cos()?.to_dtype(dtype)?;
463        Ok(Self {
464            short_cos: cos,
465            short_sin: sin,
466            long_cos: None,
467            long_sin: None,
468            original_max_position_embeddings: cfg.original_max_position_embeddings,
469        })
470    }
471
472    #[allow(clippy::too_many_arguments)]
473    fn new_scaled(
474        short_factor: &[f64],
475        long_factor: &[f64],
476        scaling_type: &ScaledRopeType,
477        long_mscale: f64,
478        short_mscale: f64,
479        cfg: &PhiRopeConfig,
480        dtype: DType,
481        dev: &Device,
482    ) -> Result<Self> {
483        let max_seq_len = cfg.max_position_embeddings;
484        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
485
486        if !matches!(scaling_type, ScaledRopeType::Su) {
487            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
488        }
489
490        if short_factor.len() != dim / 2 {
491            candle_core::bail!(
492                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
493                short_factor.len(),
494                dim / 2
495            );
496        }
497        if long_factor.len() != dim / 2 {
498            candle_core::bail!(
499                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
500                long_factor.len(),
501                dim / 2
502            );
503        }
504
505        // Short cos/sin
506        let inv_freq_short: Vec<_> = (0..dim)
507            .step_by(2)
508            .enumerate()
509            .map(|(k, i)| {
510                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
511            })
512            .collect();
513        let inv_freq_len_short = inv_freq_short.len();
514        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
515        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
516            .to_dtype(DType::F32)?
517            .reshape((max_seq_len, 1))?;
518        let freqs_short = t_short.matmul(&inv_freq_short)?;
519        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
520        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
521
522        // Long cos/sin
523        let inv_freq_long: Vec<_> = (0..dim)
524            .step_by(2)
525            .enumerate()
526            .map(|(k, i)| {
527                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
528            })
529            .collect();
530        let inv_freq_len_long = inv_freq_long.len();
531        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
532        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
533            .to_dtype(DType::F32)?
534            .reshape((max_seq_len, 1))?;
535        let freqs_long = t_long.matmul(&inv_freq_long)?;
536        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
537        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
538        Ok(Self {
539            short_cos: cos_short,
540            short_sin: sin_short,
541            long_cos: Some(cos_long),
542            long_sin: Some(sin_long),
543            original_max_position_embeddings: cfg.original_max_position_embeddings,
544        })
545    }
546
547    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
548        let cfg: PhiRopeConfig = cfg.into();
549
550        match &cfg.rope_scaling {
551            Some(PhiRopeScalingConfig::Classic {
552                short_factor,
553                long_factor,
554                scaling_type,
555            }) => {
556                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
557            }
558
559            Some(PhiRopeScalingConfig::Scaled {
560                short_factor,
561                long_factor,
562                scaling_type,
563                long_mscale,
564                short_mscale,
565            }) => Self::new_scaled(
566                short_factor,
567                long_factor,
568                scaling_type,
569                *long_mscale,
570                *short_mscale,
571                &cfg,
572                dtype,
573                dev,
574            ),
575
576            None => Self::new_unscaled(&cfg, dtype, dev),
577        }
578    }
579
580    /// Returns (sin, cos) taking into account LongRope
581    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
582        if self.long_cos.is_none() {
583            return (&self.short_sin, &self.short_cos);
584        }
585        let seq_len = position_ids.iter().max().unwrap() + 1;
586        if seq_len > self.original_max_position_embeddings {
587            (
588                self.long_sin.as_ref().unwrap(),
589                self.long_cos.as_ref().unwrap(),
590            )
591        } else {
592            (&self.short_sin, &self.short_cos)
593        }
594    }
595
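    // `forward` also handles partial rotary application: when the sin/cos tables
    // cover fewer features than the head dimension (rot_dim = 2 * last cos dim,
    // the Phi-3 / Phi-4-mini case), only the first rot_dim features of q/k are
    // rotated and the remainder is concatenated back unchanged.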
596    pub fn forward(
597        &self,
598        q: &Tensor,
599        k: &Tensor,
600        seqlen_offsets: &[usize],
601        position_ids: &[usize],
602    ) -> Result<(Tensor, Tensor)> {
603        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
604        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
605
606        let rot_dim = cos.dim(D::Minus1)? * 2;
607
608        // Case for Phi 3 / Phi 4 mini
609        if rot_dim != q.dim(D::Minus1)? {
610            let rot_dim = cos.dim(D::Minus1)? * 2;
611            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
612            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
613            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
614            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
615
616            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
617                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
618                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
619                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
620                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
621                (q_embed, k_embed)
622            } else {
623                let mut q_embeds = Vec::new();
624                let mut k_embeds = Vec::new();
625                for (i, offset) in seqlen_offsets.iter().enumerate() {
626                    let cos = cos.narrow(0, *offset, seq_len)?;
627                    let sin = sin.narrow(0, *offset, seq_len)?;
628                    let q_embed = candle_nn::rotary_emb::rope(
629                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
630                        &cos,
631                        &sin,
632                    )?;
633                    let k_embed = candle_nn::rotary_emb::rope(
634                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
635                        &cos,
636                        &sin,
637                    )?;
638                    q_embeds.push(q_embed);
639                    k_embeds.push(k_embed);
640                }
641                let q_rot = Tensor::cat(&q_embeds, 0)?;
642                let k_rot = Tensor::cat(&k_embeds, 0)?;
643                (q_rot, k_rot)
644            };
645
646            Ok((
647                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
648                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
649            ))
650        } else if seqlen_offsets.len() == 1 {
651            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
652            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
653            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
654            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
655            Ok((q_embed, k_embed))
656        } else {
657            let mut q_embeds = Vec::new();
658            let mut k_embeds = Vec::new();
659            for (i, offset) in seqlen_offsets.iter().enumerate() {
660                let cos = cos.narrow(0, *offset, seq_len)?;
661                let sin = sin.narrow(0, *offset, seq_len)?;
662                let q_embed =
663                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
664                let k_embed =
665                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
666                q_embeds.push(q_embed);
667                k_embeds.push(k_embed);
668            }
669            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
670        }
671    }
672}
673
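#[cfg(test)]
mod rope_table_tests {
    use super::*;

    // Hedged sketch: rebuilds an unscaled RoPE table the same way as
    // `PhiRotaryEmbedding::new_unscaled` and checks its shape. The constants
    // (dim = 8, max_seq_len = 16, theta = 10_000.0) are illustrative only.
    #[test]
    fn unscaled_rope_table_shape() -> Result<()> {
        let dev = Device::Cpu;
        let (dim, max_seq_len, theta) = (8usize, 16usize, 10_000f64);
        let inv_freq: Vec<f32> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq = Tensor::from_vec(inv_freq, (1, dim / 2), &dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, &dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        assert_eq!(freqs.sin()?.dims(), &[max_seq_len, dim / 2]);
        assert_eq!(freqs.cos()?.dims(), &[max_seq_len, dim / 2]);
        Ok(())
    }
}
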
674/// RoPE for Llama3
675#[derive(Debug, Clone)]
676pub struct Llama3RotaryEmbedding(RotaryEmbedding);
677
678#[derive(Debug, Clone, Deserialize, Serialize, Default)]
679pub enum Llama3RopeType {
680    #[serde(rename = "llama3")]
681    Llama3,
682    #[serde(rename = "linear")]
683    Linear,
684    #[default]
685    #[serde(rename = "default")]
686    Default,
687}
688
689#[derive(Debug, Clone, Deserialize, Serialize, Default)]
690pub struct Llama3RopeConfig {
691    pub factor: f32,
692    pub low_freq_factor: Option<f32>,
693    pub high_freq_factor: Option<f32>,
694    pub original_max_position_embeddings: Option<usize>,
695    pub rope_type: Llama3RopeType,
696}
697
698fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
699    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
700    (0..head_dim)
701        .step_by(2)
702        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
703        .collect()
704}
705
706fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
707    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
708    (0..head_dim)
709        .step_by(2)
710        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
711        .collect()
712}
713
714// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
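// The Llama 3 rule partitions the default inverse frequencies by wavelength:
//   * wavelen < original_max / high_freq_factor -> keep the frequency as-is
//   * wavelen > original_max / low_freq_factor  -> divide the frequency by `factor`
//   * otherwise interpolate between the two using
//     smooth = (original_max / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor).
// Illustrative numbers (not from a real config): with original_max = 8192, factor = 8,
// low_freq_factor = 1 and high_freq_factor = 4, a component with wavelen = 4096 gives
// smooth = (2 - 1) / (4 - 1) = 1/3, so the scaled frequency is (2/3) * freq / 8 + (1/3) * freq.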
715impl Llama3RotaryEmbedding {
716    pub fn new_llama3(
717        dtype: DType,
718        cfg: &llama::Config,
719        dev: &Device,
720        is_gpt_neox: bool,
721    ) -> Result<Self> {
722        match &cfg.rope_scaling {
723            None
724            | Some(Llama3RopeConfig {
725                rope_type: Llama3RopeType::Default,
726                ..
727            }) => Ok(Self(RotaryEmbedding::new(
728                cfg.rope_theta,
729                cfg.hidden_size / cfg.num_attention_heads,
730                cfg.max_position_embeddings,
731                dev,
732                is_gpt_neox,
733                dtype,
734            )?)),
735            Some(Llama3RopeConfig {
736                rope_type: Llama3RopeType::Llama3,
737                factor,
738                low_freq_factor,
739                high_freq_factor,
740                original_max_position_embeddings,
741            }) => {
742                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
743                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
744                let original_max_position_embeddings = original_max_position_embeddings
745                    .context("original_max_position_embeddings is required")?;
746
747                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
748                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
749
750                let inv_freq = calculate_default_inv_freq(cfg)
751                    .into_iter()
752                    .map(|freq| {
753                        let wavelen = 2. * PI / freq;
754                        if wavelen < high_freq_wavelen {
755                            freq
756                        } else if wavelen > low_freq_wavelen {
757                            freq / *factor
758                        } else {
759                            let smooth = (original_max_position_embeddings as f32 / wavelen
760                                - low_freq_factor)
761                                / (high_freq_factor - low_freq_factor);
762                            (1. - smooth) * freq / *factor + smooth * freq
763                        }
764                    })
765                    .collect::<Vec<_>>();
766                let inv_freq_len = inv_freq.len();
767                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
768                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
769                    .to_dtype(DType::F32)?
770                    .reshape((cfg.max_position_embeddings, 1))?;
771                let freqs = t.matmul(&inv_freq)?;
772                let sin = freqs.sin()?.to_dtype(dtype)?;
773                let cos = freqs.cos()?.to_dtype(dtype)?;
774                Ok(Self(RotaryEmbedding {
775                    sin,
776                    cos,
777                    is_gpt_neox,
778                }))
779            }
780            Some(Llama3RopeConfig {
781                rope_type: Llama3RopeType::Linear,
782                factor,
783                ..
784            }) => {
785                let inv_freq_vec = calculate_default_inv_freq(cfg)
786                    .into_iter()
787                    .map(|freq| freq / *factor)
788                    .collect::<Vec<_>>();
789                let inv_freq_len = inv_freq_vec.len();
790                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
791                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
792                    .to_dtype(DType::F32)?
793                    .reshape((cfg.max_position_embeddings, 1))?;
794                let freqs = t.matmul(&inv_freq)?;
795                let sin = freqs.sin()?.to_dtype(dtype)?;
796                let cos = freqs.cos()?.to_dtype(dtype)?;
797                Ok(Self(RotaryEmbedding {
798                    sin,
799                    cos,
800                    is_gpt_neox,
801                }))
802            }
803        }
804    }
805
806    pub fn new_llama4(
807        dtype: DType,
808        cfg: &llama4::TextConfig,
809        dev: &Device,
810        is_gpt_neox: bool,
811    ) -> Result<Self> {
812        match &cfg.rope_scaling {
813            None
814            | Some(Llama3RopeConfig {
815                rope_type: Llama3RopeType::Default,
816                ..
817            }) => Ok(Self(RotaryEmbedding::new(
818                cfg.rope_theta,
819                cfg.hidden_size / cfg.num_attention_heads,
820                cfg.max_position_embeddings,
821                dev,
822                is_gpt_neox,
823                dtype,
824            )?)),
825            Some(Llama3RopeConfig {
826                rope_type: Llama3RopeType::Llama3,
827                factor,
828                low_freq_factor,
829                high_freq_factor,
830                original_max_position_embeddings,
831            }) => {
832                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
833                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
834                let original_max_position_embeddings = original_max_position_embeddings
835                    .context("original_max_position_embeddings is required")?;
836
837                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
838                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
839
840                let inv_freq = calculate_default_inv_freq_llama4(cfg)
841                    .into_iter()
842                    .map(|freq| {
843                        let wavelen = 2. * PI / freq;
844                        if wavelen < high_freq_wavelen {
845                            freq
846                        } else if wavelen > low_freq_wavelen {
847                            freq / *factor
848                        } else {
849                            let smooth = (original_max_position_embeddings as f32 / wavelen
850                                - low_freq_factor)
851                                / (high_freq_factor - low_freq_factor);
852                            (1. - smooth) * freq / *factor + smooth * freq
853                        }
854                    })
855                    .collect::<Vec<_>>();
856                let inv_freq_len = inv_freq.len();
857                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
858                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
859                    .to_dtype(DType::F32)?
860                    .reshape((cfg.max_position_embeddings, 1))?;
861                let freqs = t.matmul(&inv_freq)?;
862                let sin = freqs.sin()?.to_dtype(dtype)?;
863                let cos = freqs.cos()?.to_dtype(dtype)?;
864                Ok(Self(RotaryEmbedding {
865                    sin,
866                    cos,
867                    is_gpt_neox,
868                }))
869            }
870            Some(Llama3RopeConfig {
871                rope_type: Llama3RopeType::Linear,
872                factor,
873                ..
874            }) => {
875                let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
876                    .into_iter()
877                    .map(|freq| freq / *factor)
878                    .collect::<Vec<_>>();
879                let inv_freq_len = inv_freq_vec.len();
880                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
881                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
882                    .to_dtype(DType::F32)?
883                    .reshape((cfg.max_position_embeddings, 1))?;
884                let freqs = t.matmul(&inv_freq)?;
885                let sin = freqs.sin()?.to_dtype(dtype)?;
886                let cos = freqs.cos()?.to_dtype(dtype)?;
887                Ok(Self(RotaryEmbedding {
888                    sin,
889                    cos,
890                    is_gpt_neox,
891                }))
892            }
893        }
894    }
895
896    pub fn new_mllama3(
897        dtype: DType,
898        cfg: &MLlamaTextConfig,
899        dev: &Device,
900        is_gpt_neox: bool,
901    ) -> Result<Self> {
902        match &cfg.rope_scaling {
903            None
904            | Some(MLlamaRopeScaling {
905                rope_type: MLlamaRopeType::Default,
906                ..
907            }) => Ok(Self(RotaryEmbedding::new(
908                cfg.rope_theta,
909                cfg.hidden_size / cfg.num_attention_heads,
910                cfg.max_position_embeddings,
911                dev,
912                is_gpt_neox,
913                dtype,
914            )?)),
915            Some(MLlamaRopeScaling {
916                rope_type: MLlamaRopeType::Llama3,
917                original_max_position_embeddings,
918                factor,
919                attention_factor: _,
920                beta_fast: _,
921                beta_slow: _,
922                short_factor: _,
923                long_factor: _,
924                low_freq_factor,
925                high_freq_factor,
926            }) => {
927                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
928                let low_freq_factor = low_freq_factor
929                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
930                let high_freq_factor = high_freq_factor
931                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
932
933                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
934                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
935
936                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
937
938                let inv_freq = (0..head_dim)
939                    .step_by(2)
940                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
941                    .map(|freq| {
942                        let wavelen = 2. * PI / freq;
943                        if wavelen < high_freq_wavelen {
944                            freq
945                        } else if wavelen > low_freq_wavelen {
946                            freq / factor
947                        } else {
948                            let smooth = (*original_max_position_embeddings as f32 / wavelen
949                                - low_freq_factor)
950                                / (high_freq_factor - low_freq_factor);
951                            (1. - smooth) * freq / factor + smooth * freq
952                        }
953                    })
954                    .collect::<Vec<_>>();
955                let inv_freq_len = inv_freq.len();
956                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
957
958                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
959                    .to_dtype(DType::F32)?
960                    .reshape((cfg.max_position_embeddings, 1))?;
961                let freqs = t.matmul(&inv_freq)?;
962                let sin = freqs.sin()?.to_dtype(dtype)?;
963                let cos = freqs.cos()?.to_dtype(dtype)?;
964                Ok(Self(RotaryEmbedding {
965                    sin,
966                    cos,
967                    is_gpt_neox,
968                }))
969            }
970            Some(MLlamaRopeScaling {
971                rope_type: other, ..
972            }) => {
973                candle_core::bail!(
974                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
975                )
976            }
977        }
978    }
979
980    pub fn forward(
981        &self,
982        q: &Tensor,
983        k: &Tensor,
984        seqlen_offsets: &[usize],
985    ) -> Result<(Tensor, Tensor)> {
986        self.0.forward(q, k, seqlen_offsets)
987    }
988}
989
990/// RoPE for SmolLm3
991#[derive(Debug, Clone)]
992pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);
993
994#[derive(Debug, Clone, Deserialize, Serialize, Default)]
995pub enum SmolLm3RopeType {
996    #[serde(rename = "llama3")]
997    Llama3,
998    #[serde(rename = "linear")]
999    Linear,
1000    #[default]
1001    #[serde(rename = "default")]
1002    Default,
1003}
1004
1005#[derive(Debug, Clone, Deserialize, Serialize, Default)]
1006pub struct SmolLm3RopeConfig {
1007    pub factor: f32,
1008    pub low_freq_factor: Option<f32>,
1009    pub high_freq_factor: Option<f32>,
1010    pub original_max_position_embeddings: Option<usize>,
1011    pub rope_type: SmolLm3RopeType,
1012}
1013
1014fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
1015    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1016    (0..head_dim)
1017        .step_by(2)
1018        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
1019        .collect()
1020}
1021
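// SmolLM3 reuses the Llama 3 wavelength-based scaling rule documented above; only the
// config type differs.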
1022impl SmolLm3RotaryEmbedding {
1023    pub fn new_llama3(
1024        dtype: DType,
1025        cfg: &smollm3::Config,
1026        dev: &Device,
1027        is_gpt_neox: bool,
1028    ) -> Result<Self> {
1029        match &cfg.rope_scaling {
1030            None
1031            | Some(SmolLm3RopeConfig {
1032                rope_type: SmolLm3RopeType::Default,
1033                ..
1034            }) => Ok(Self(RotaryEmbedding::new(
1035                cfg.rope_theta,
1036                cfg.hidden_size / cfg.num_attention_heads,
1037                cfg.max_position_embeddings,
1038                dev,
1039                is_gpt_neox,
1040                dtype,
1041            )?)),
1042            Some(SmolLm3RopeConfig {
1043                rope_type: SmolLm3RopeType::Llama3,
1044                factor,
1045                low_freq_factor,
1046                high_freq_factor,
1047                original_max_position_embeddings,
1048            }) => {
1049                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
1050                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
1051                let original_max_position_embeddings = original_max_position_embeddings
1052                    .context("original_max_position_embeddings is required")?;
1053
1054                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
1055                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
1056
1057                let inv_freq = calculate_default_inv_freq_smollm3(cfg)
1058                    .into_iter()
1059                    .map(|freq| {
1060                        let wavelen = 2. * PI / freq;
1061                        if wavelen < high_freq_wavelen {
1062                            freq
1063                        } else if wavelen > low_freq_wavelen {
1064                            freq / *factor
1065                        } else {
1066                            let smooth = (original_max_position_embeddings as f32 / wavelen
1067                                - low_freq_factor)
1068                                / (high_freq_factor - low_freq_factor);
1069                            (1. - smooth) * freq / *factor + smooth * freq
1070                        }
1071                    })
1072                    .collect::<Vec<_>>();
1073                let inv_freq_len = inv_freq.len();
1074                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1075                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1076                    .to_dtype(DType::F32)?
1077                    .reshape((cfg.max_position_embeddings, 1))?;
1078                let freqs = t.matmul(&inv_freq)?;
1079                let sin = freqs.sin()?.to_dtype(dtype)?;
1080                let cos = freqs.cos()?.to_dtype(dtype)?;
1081                Ok(Self(RotaryEmbedding {
1082                    sin,
1083                    cos,
1084                    is_gpt_neox,
1085                }))
1086            }
1087            Some(SmolLm3RopeConfig {
1088                rope_type: SmolLm3RopeType::Linear,
1089                factor,
1090                ..
1091            }) => {
1092                let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
1093                    .into_iter()
1094                    .map(|freq| freq / *factor)
1095                    .collect::<Vec<_>>();
1096                let inv_freq_len = inv_freq_vec.len();
1097                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
1098                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1099                    .to_dtype(DType::F32)?
1100                    .reshape((cfg.max_position_embeddings, 1))?;
1101                let freqs = t.matmul(&inv_freq)?;
1102                let sin = freqs.sin()?.to_dtype(dtype)?;
1103                let cos = freqs.cos()?.to_dtype(dtype)?;
1104                Ok(Self(RotaryEmbedding {
1105                    sin,
1106                    cos,
1107                    is_gpt_neox,
1108                }))
1109            }
1110        }
1111    }
1112
1113    pub fn forward(
1114        &self,
1115        q: &Tensor,
1116        k: &Tensor,
1117        seqlen_offsets: &[usize],
1118    ) -> Result<(Tensor, Tensor)> {
1119        self.0.forward(q, k, seqlen_offsets)
1120    }
1121}
1122
1123// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
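// Multimodal RoPE (M-RoPE): `position_ids` stacks three planes along dim 0
// (temporal, height, width in the Qwen2-VL formulation). `compute_cos_sin` builds
// cos/sin for all three planes and reassembles the head dimension from
// `mrope_section`-sized chunks, taking chunk i from plane i % 3, so different
// feature groups rotate with different position components.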
1124#[derive(Debug, Clone)]
1125pub struct Qwen2VLRotaryEmbedding {
1126    inv_freq: Tensor,
1127    mrope_section: Vec<usize>,
1128}
1129
1130impl Qwen2VLRotaryEmbedding {
1131    pub fn new(
1132        base: f32,
1133        head_dim: usize,
1134        device: &Device,
1135        mrope_section: Vec<usize>,
1136    ) -> Result<Self> {
1137        let inv_freq: Vec<_> = (0..head_dim)
1138            .step_by(2)
1139            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1140            .collect();
1141        let inv_freq_len = inv_freq.len();
1142        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1143        Ok(Self {
1144            inv_freq,
1145            mrope_section,
1146        })
1147    }
1148
1149    /// Returns `(cos, sin)` computed from the 3-plane multimodal position ids.
1150    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1151        let inv_freq_expanded =
1152            self.inv_freq
1153                .reshape((1, 1, (), 1))?
1154                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1155        let position_ids_expanded = position_ids.unsqueeze(2)?;
1156        let freqs = inv_freq_expanded
1157            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1158            .transpose(2, 3)?;
1159        let cos = freqs.cos()?;
1160        let sin = freqs.sin()?;
1161
1162        let cos = Tensor::cat(
1163            &cos.split(&self.mrope_section, D::Minus1)?
1164                .into_iter()
1165                .enumerate()
1166                .map(|(i, m)| m.i(i % 3))
1167                .collect::<Result<Vec<_>>>()?,
1168            D::Minus1,
1169        )?
1170        .squeeze(0)?
1171        .to_dtype(dtype)?
1172        .contiguous()?;
1173        let sin = Tensor::cat(
1174            &sin.split(&self.mrope_section, D::Minus1)?
1175                .into_iter()
1176                .enumerate()
1177                .map(|(i, m)| m.i(i % 3))
1178                .collect::<Result<Vec<_>>>()?,
1179            D::Minus1,
1180        )?
1181        .squeeze(0)?
1182        .to_dtype(dtype)?
1183        .contiguous()?;
1184
1185        Ok((cos, sin))
1186    }
1187
1188    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
1189    pub fn forward(
1190        &self,
1191        (cos, sin): &(Tensor, Tensor),
1192        q: &mut Tensor,
1193        k: &mut Tensor,
1194    ) -> Result<()> {
1195        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1196        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1197        Ok(())
1198    }
1199}
1200
1201#[derive(Debug, Clone)]
1202pub struct Qwen2_5VLRotaryEmbedding {
1203    inv_freq: Tensor,
1204    mrope_section: Vec<usize>,
1205}
1206
1207impl Qwen2_5VLRotaryEmbedding {
1208    pub fn new(
1209        base: f32,
1210        head_dim: usize,
1211        device: &Device,
1212        mrope_section: Vec<usize>,
1213    ) -> Result<Self> {
1214        let inv_freq: Vec<_> = (0..head_dim)
1215            .step_by(2)
1216            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1217            .collect();
1218        let inv_freq_len = inv_freq.len();
1219        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1220        Ok(Self {
1221            inv_freq,
1222            mrope_section,
1223        })
1224    }
1225
1226    /// Returns `(cos, sin)` computed from the 3-plane multimodal position ids.
1227    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1228        let inv_freq_expanded =
1229            self.inv_freq
1230                .reshape((1, 1, (), 1))?
1231                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1232        let position_ids_expanded = position_ids.unsqueeze(2)?;
1233        let freqs = inv_freq_expanded
1234            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1235            .transpose(2, 3)?;
1236        let cos = freqs.cos()?;
1237        let sin = freqs.sin()?;
1238
1239        let cos = Tensor::cat(
1240            &cos.split(&self.mrope_section, D::Minus1)?
1241                .into_iter()
1242                .enumerate()
1243                .map(|(i, m)| m.i(i % 3))
1244                .collect::<Result<Vec<_>>>()?,
1245            D::Minus1,
1246        )?
1247        .squeeze(0)?
1248        .to_dtype(dtype)?
1249        .contiguous()?;
1250        let sin = Tensor::cat(
1251            &sin.split(&self.mrope_section, D::Minus1)?
1252                .into_iter()
1253                .enumerate()
1254                .map(|(i, m)| m.i(i % 3))
1255                .collect::<Result<Vec<_>>>()?,
1256            D::Minus1,
1257        )?
1258        .squeeze(0)?
1259        .to_dtype(dtype)?
1260        .contiguous()?;
1261
1262        Ok((cos, sin))
1263    }
1264
1265    pub fn forward(
1266        &self,
1267        (cos, sin): &(Tensor, Tensor),
1268        q: &mut Tensor,
1269        k: &mut Tensor,
1270    ) -> Result<()> {
1271        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1272        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1273        Ok(())
1274    }
1275}
1276
1277#[derive(Debug, Clone)]
1278pub struct Qwen3VLRotaryEmbedding {
1279    inv_freq: Tensor,
1280    mrope_section: Vec<usize>,
1281}
1282
1283impl Qwen3VLRotaryEmbedding {
1284    pub fn new(
1285        base: f32,
1286        head_dim: usize,
1287        device: &Device,
1288        mrope_section: Vec<usize>,
1289    ) -> Result<Self> {
1290        let inv_freq: Vec<_> = (0..head_dim)
1291            .step_by(2)
1292            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1293            .collect();
1294        let inv_freq_len = inv_freq.len();
1295        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
1296        Ok(Self {
1297            inv_freq,
1298            mrope_section,
1299        })
1300    }
1301
1302    /// Returns `(cos, sin)` computed from the 3-plane multimodal position ids.
1303    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
1304        let inv_freq_expanded =
1305            self.inv_freq
1306                .reshape((1, 1, (), 1))?
1307                .repeat((3, position_ids.dim(1)?, 1, 1))?;
1308        let position_ids_expanded = position_ids.unsqueeze(2)?;
1309        let freqs = inv_freq_expanded
1310            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
1311            .transpose(2, 3)?;
1312        let cos = freqs.cos()?;
1313        let sin = freqs.sin()?;
1314
1315        let cos = Tensor::cat(
1316            &cos.split(&self.mrope_section, D::Minus1)?
1317                .into_iter()
1318                .enumerate()
1319                .map(|(i, m)| m.i(i % 3))
1320                .collect::<Result<Vec<_>>>()?,
1321            D::Minus1,
1322        )?
1323        .squeeze(0)?
1324        .to_dtype(dtype)?
1325        .contiguous()?;
1326        let sin = Tensor::cat(
1327            &sin.split(&self.mrope_section, D::Minus1)?
1328                .into_iter()
1329                .enumerate()
1330                .map(|(i, m)| m.i(i % 3))
1331                .collect::<Result<Vec<_>>>()?,
1332            D::Minus1,
1333        )?
1334        .squeeze(0)?
1335        .to_dtype(dtype)?
1336        .contiguous()?;
1337
1338        Ok((cos, sin))
1339    }
1340
1341    pub fn forward(
1342        &self,
1343        (cos, sin): &(Tensor, Tensor),
1344        q: &mut Tensor,
1345        k: &mut Tensor,
1346    ) -> Result<()> {
1347        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1348        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1349        Ok(())
1350    }
1351}
1352
1353#[derive(Debug, Clone)]
1354pub struct DeepSeekV2RotaryEmbedding {
1355    sin: Tensor,
1356    cos: Tensor,
1357}
1358
1359#[derive(Debug, Clone, Deserialize, Serialize)]
1360#[serde(untagged)]
1361pub enum DeepSeekV2RopeScaling {
1362    Yarn {
1363        original_max_position_embeddings: usize,
1364        beta_fast: f32,
1365        beta_slow: f32,
1366        mscale: f32,
1367        mscale_all_dim: f32,
1368        factor: f32,
1369        #[serde(rename = "type")]
1370        scaling_type: ScaledRopeType,
1371    },
1372    LinearOrDynamic {
1373        #[serde(rename = "type")]
1374        scaling_type: ScaledRopeType,
1375        factor: f64,
1376    },
1377}
1378
1379pub struct DeepSeekV2RopeConfig {
1380    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
1381    pub max_position_embeddings: usize,
1382    pub rope_theta: f32,
1383    pub qk_rope_head_dim: usize,
1384}
1385
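// YaRN scaling (see `new_yarn` below): extrapolated frequencies (`freq_extra`) and
// interpolated frequencies (`freq_inter`, i.e. the same frequencies divided by
// `factor`) are blended per dimension with a linear ramp mask derived from the
// `beta_fast`/`beta_slow` correction range, and the resulting sin/cos tables are
// rescaled by yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim),
// where yarn_get_mscale(s, m) = 0.1 * m * ln(s) + 1 for s > 1 (and 1 otherwise).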
1386impl DeepSeekV2RotaryEmbedding {
1387    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1388        let max_seq_len = cfg.max_position_embeddings;
1389        let dim = cfg.qk_rope_head_dim;
1390
1391        let inv_freq: Vec<_> = (0..dim)
1392            .step_by(2)
1393            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1394            .collect();
1395        let inv_freq_len = inv_freq.len();
1396        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1397        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1398            .to_dtype(DType::F32)?
1399            .reshape((max_seq_len, 1))?;
1400        let freqs = t.matmul(&inv_freq)?;
1401
1402        let sin = freqs.sin()?.to_dtype(dtype)?;
1403        let cos = freqs.cos()?.to_dtype(dtype)?;
1404
1405        Ok(Self { sin, cos })
1406    }
1407
1408    fn yarn_find_correction_dim(
1409        num_rot: f32,
1410        dim: usize,
1411        base: f32,
1412        max_position_embeddings: usize,
1413    ) -> f32 {
1414        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1415            / (2. * base.ln())
1416    }
1417
1418    fn yarn_find_correction_range(
1419        low_rot: f32,
1420        high_rot: f32,
1421        dim: usize,
1422        base: f32,
1423        max_position_embeddings: usize,
1424    ) -> (f32, f32) {
1425        let low =
1426            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1427        let high =
1428            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1429        (low.max(0.), high.min(dim as f32 - 1.))
1430    }
1431
1432    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1433        if min == max {
1434            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
1435            max += 0.001;
1436        }
1437        let linear_func =
1438            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1439        linear_func.clamp(0., 1)
1440    }
1441
1442    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1443        if scale <= 1. {
1444            return 1.;
1445        }
1446        0.1 * mscale * scale.ln() + 1.
1447    }
1448
1449    #[allow(clippy::too_many_arguments)]
1450    fn new_yarn(
1451        cfg: &DeepSeekV2RopeConfig,
1452        dtype: DType,
1453        dev: &Device,
1454        original_max_position_embeddings: usize,
1455        beta_fast: f32,
1456        beta_slow: f32,
1457        factor: f32,
1458        mscale: f32,
1459        mscale_all_dim: f32,
1460    ) -> Result<Self> {
1461        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1462            .step_by(2)
1463            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1464            .collect();
1465        let freq_extra_len = freq_extra.len();
1466        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1467        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1468            .step_by(2)
1469            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1470            .collect();
1471        let freq_inter_len = freq_inter.len();
1472        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1473
1474        let (low, high) = Self::yarn_find_correction_range(
1475            beta_fast,
1476            beta_slow,
1477            cfg.qk_rope_head_dim,
1478            cfg.rope_theta,
1479            original_max_position_embeddings,
1480        );
1481        let inv_freq_mask =
1482            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1483        let inv_freq = freq_inter
1484            .broadcast_mul(&(1. - &inv_freq_mask)?)?
1485            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1486
1487        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1488            .to_dtype(DType::F32)?
1489            .reshape((cfg.max_position_embeddings, 1))?;
1490        let freqs = t.matmul(&inv_freq)?;
1491
1492        let mscale =
1493            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1494        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1495        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1496
1497        Ok(Self { sin, cos })
1498    }
1499
1500    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1501        match &cfg.rope_scaling {
1502            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1503                scaling_type: _,
1504                factor: _,
1505            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1506            Some(DeepSeekV2RopeScaling::Yarn {
1507                original_max_position_embeddings,
1508                beta_fast,
1509                beta_slow,
1510                factor,
1511                mscale,
1512                mscale_all_dim,
1513                scaling_type: _,
1514            }) => Self::new_yarn(
1515                cfg,
1516                dtype,
1517                dev,
1518                *original_max_position_embeddings,
1519                *beta_fast,
1520                *beta_slow,
1521                *factor,
1522                *mscale,
1523                *mscale_all_dim,
1524            ),
1525            None => Self::new_unscaled(cfg, dtype, dev),
1526        }
1527    }
1528
1529    pub fn forward(
1530        &self,
1531        q: &Tensor,
1532        k: &Tensor,
1533        seqlen_offsets: &[usize],
1534    ) -> Result<(Tensor, Tensor)> {
1535        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1536
1537        if seqlen_offsets.len() == 1 {
1538            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1539            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1540            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1541            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1542            Ok((q_embed, k_embed))
1543        } else {
1544            let mut q_embeds = Vec::new();
1545            let mut k_embeds = Vec::new();
1546            for (i, offset) in seqlen_offsets.iter().enumerate() {
1547                let cos = self.cos.narrow(0, *offset, seq_len)?;
1548                let sin = self.sin.narrow(0, *offset, seq_len)?;
1549                let q_embed = candle_nn::rotary_emb::rope_i(
1550                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
1551                    &cos,
1552                    &sin,
1553                )?;
1554                let k_embed = candle_nn::rotary_emb::rope_i(
1555                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
1556                    &cos,
1557                    &sin,
1558                )?;
1559                q_embeds.push(q_embed);
1560                k_embeds.push(k_embed);
1561            }
1562            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1563        }
1564    }
1565}
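
// Illustrative usage sketch of the rotary embedding above, assuming the surrounding
// type is named `DeepSeekV2RotaryEmbedding` and that `rope_cfg`, `dev`, `q_pe`,
// `k_pe` and `past_len` are defined by the caller (all of these names are assumptions):
//
//     let rope = DeepSeekV2RotaryEmbedding::new(&rope_cfg, DType::BF16, &dev)?;
//     // q_pe/k_pe: (batch, n_heads, seq_len, qk_rope_head_dim); one KV-cache
//     // offset per sequence in the batch.
//     let (q_pe, k_pe) = rope.forward(&q_pe, &k_pe, &[past_len])?;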
1566
1567#[derive(Debug, Clone)]
1568pub struct Phi4MMRotaryEmbedding {
1569    short_sin: Tensor,
1570    short_cos: Tensor,
1571    long_cos: Option<Tensor>,
1572    long_sin: Option<Tensor>,
1573    original_max_position_embeddings: usize,
1574}
1575
1576#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1577#[serde(rename_all = "lowercase")]
1578pub enum Phi4MMScaledRopeType {
1579    #[serde(alias = "longrope")]
1580    LongRope,
1581    #[default]
1582    Default,
1583}
1584
1585#[derive(Debug, Clone, Deserialize, Serialize)]
1586pub struct Phi4MMRopeScalingConfig {
1587    short_factor: Option<Vec<f64>>,
1588    long_factor: Option<Vec<f64>>,
1589    #[serde(rename = "type")]
1590    scaling_type: Phi4MMScaledRopeType,
1591}
1592
1593impl Phi4MMRotaryEmbedding {
1594    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1595        let max_seq_len = cfg.max_position_embeddings;
1596        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1597
1598        let inv_freq: Vec<_> = (0..dim)
1599            .step_by(2)
1600            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1601            .collect();
1602        let inv_freq_len = inv_freq.len();
1603        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1604        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1605            .to_dtype(DType::F32)?
1606            .reshape((max_seq_len, 1))?;
1607        let freqs = t.matmul(&inv_freq)?;
1608        let sin = freqs.sin()?.to_dtype(dtype)?;
1609        let cos = freqs.cos()?.to_dtype(dtype)?;
1610        Ok(Self {
1611            short_cos: cos,
1612            short_sin: sin,
1613            long_cos: None,
1614            long_sin: None,
1615            original_max_position_embeddings: cfg.original_max_position_embeddings,
1616        })
1617    }
1618
1619    #[allow(clippy::too_many_arguments)]
1620    fn new_longrope(
1621        short_factor: &[f64],
1622        long_factor: &[f64],
1623        cfg: &Phi4MMConfig,
1624        dtype: DType,
1625        dev: &Device,
1626    ) -> Result<Self> {
1627        let max_seq_len = cfg.max_position_embeddings;
1628        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1629
1630        // Calculate scale
1631        let scale =
1632            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1633        let scaling_factor = if scale <= 1.0 {
1634            1.0
1635        } else {
1636            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1637        };
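
        // For instance (illustrative numbers only): with original_max_position_embeddings
        // = 4096 and max_position_embeddings = 131072, scale = 32 and
        // scaling_factor = sqrt(1 + ln(32) / ln(4096)) ≈ 1.19; this magnitude correction
        // is applied to both the short and long cos/sin tables below.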
1638
1639        // Short cos/sin
1640        let inv_freq_short: Vec<_> = (0..dim)
1641            .step_by(2)
1642            .enumerate()
1643            .map(|(k, i)| {
1644                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1645            })
1646            .collect();
1647        let inv_freq_len_short = inv_freq_short.len();
1648        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1649        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1650            .to_dtype(DType::F32)?
1651            .reshape((max_seq_len, 1))?;
1652        let freqs_short = t_short.matmul(&inv_freq_short)?;
1653        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1654        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1655
1656        // Long cos/sin
1657        let inv_freq_long: Vec<_> = (0..dim)
1658            .step_by(2)
1659            .enumerate()
1660            .map(|(k, i)| {
1661                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1662            })
1663            .collect();
1664        let inv_freq_len_long = inv_freq_long.len();
1665        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1666        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1667            .to_dtype(DType::F32)?
1668            .reshape((max_seq_len, 1))?;
1669        let freqs_long = t_long.matmul(&inv_freq_long)?;
1670        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1671        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1672
1673        Ok(Self {
1674            short_cos: cos_short,
1675            short_sin: sin_short,
1676            long_cos: Some(cos_long),
1677            long_sin: Some(sin_long),
1678            original_max_position_embeddings: cfg.original_max_position_embeddings,
1679        })
1680    }
1681
1682    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1683        match &cfg.rope_scaling {
1684            Some(Phi4MMRopeScalingConfig {
1685                scaling_type: Phi4MMScaledRopeType::LongRope,
1686                short_factor: Some(short_factor),
1687                long_factor: Some(long_factor),
1688            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1689
1690            _ => Self::new_unscaled(cfg, dtype, dev),
1691        }
1692    }
1693
1694    /// Returns (sin, cos) taking into account LongRope
1695    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1696        if self.long_cos.is_none() {
1697            return (&self.short_sin, &self.short_cos);
1698        }
1699        let seq_len = position_ids.iter().max().unwrap() + 1;
1700        if seq_len > self.original_max_position_embeddings {
1701            (
1702                self.long_sin.as_ref().unwrap(),
1703                self.long_cos.as_ref().unwrap(),
1704            )
1705        } else {
1706            (&self.short_sin, &self.short_cos)
1707        }
1708    }
1709
1710    pub fn forward(
1711        &self,
1712        q: &Tensor,
1713        k: &Tensor,
1714        seqlen_offsets: &[usize],
1715        position_ids: &[usize],
1716    ) -> Result<(Tensor, Tensor)> {
1717        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1718        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1719
1720        let rot_dim = cos.dim(D::Minus1)? * 2;
1721        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1722        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1723        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1724        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1725
1726        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1727            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1728            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1729            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1730            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1731            (q_embed, k_embed)
1732        } else {
1733            let mut q_embeds = Vec::new();
1734            let mut k_embeds = Vec::new();
1735            for (i, offset) in seqlen_offsets.iter().enumerate() {
1736                let cos = cos.narrow(0, *offset, seq_len)?;
1737                let sin = sin.narrow(0, *offset, seq_len)?;
1738                let q_embed = candle_nn::rotary_emb::rope(
1739                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1740                    &cos,
1741                    &sin,
1742                )?;
1743                let k_embed = candle_nn::rotary_emb::rope(
1744                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1745                    &cos,
1746                    &sin,
1747                )?;
1748                q_embeds.push(q_embed);
1749                k_embeds.push(k_embed);
1750            }
1751            let q_rot = Tensor::cat(&q_embeds, 0)?;
1752            let k_rot = Tensor::cat(&k_embeds, 0)?;
1753            (q_rot, k_rot)
1754        };
1755
1756        Ok((
1757            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1758            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1759        ))
1760    }
1761}
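
// Sketch of the partial-rotary split performed in `forward` above (shapes are
// illustrative assumptions): with head_dim = 128 and partial_rotary_factor = 0.75,
// the cos/sin tables have 48 columns, so rot_dim = 96; only q[..., :96] and
// k[..., :96] are rotated, and the remaining 32 dimensions pass through unchanged
// before being re-concatenated.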
1762
1763#[derive(Debug, Clone)]
1764pub struct Gemma3nRotaryEmbedding(RotaryEmbedding);
1765
1766#[derive(Debug, Clone, Deserialize, Serialize)]
1767#[serde(rename_all = "lowercase")]
1768pub enum Gemma3nScaledRopeType {
1769    #[serde(alias = "linear")]
1770    Linear,
1771}
1772
1773#[derive(Debug, Clone, Deserialize, Serialize)]
1774pub struct Gemma3nRopeScalingConfig {
1775    factor: f64,
1776    rope_type: Gemma3nScaledRopeType,
1777}
1778
1779impl Gemma3nRotaryEmbedding {
1780    fn new_linear(
1781        cfg: &Gemma3nTextConfig,
1782        factor: f64,
1783        is_gpt_neox: bool,
1784        dtype: DType,
1785        dev: &Device,
1786    ) -> Result<Self> {
1787        let max_seq_len = cfg.max_position_embeddings;
1788        let dim = cfg.head_dim;
1789
1790        let inv_freq: Vec<_> = (0..dim)
1791            .step_by(2)
1792            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1793            .collect();
1794        let inv_freq_len = inv_freq.len();
1795        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1796        let inv_freq = (inv_freq / factor)?;
1797
1798        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1799            .to_dtype(DType::F32)?
1800            .reshape((max_seq_len, 1))?;
1801        let freqs = t.matmul(&inv_freq)?;
1802        let sin = freqs.sin()?.to_dtype(dtype)?;
1803        let cos = freqs.cos()?.to_dtype(dtype)?;
1804        Ok(Self(RotaryEmbedding {
1805            cos,
1806            sin,
1807            is_gpt_neox,
1808        }))
1809    }
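
    // Note: dividing `inv_freq` by `factor` (as in `new_linear` above) is standard
    // "linear" RoPE scaling; it is equivalent to evaluating the unscaled rotary
    // embedding at position `p / factor` instead of `p`.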
1810
1811    pub fn new(
1812        is_gpt_neox: bool,
1813        dtype: DType,
1814        cfg: &Gemma3nTextConfig,
1815        dev: &Device,
1816    ) -> Result<Self> {
1817        match &cfg.rope_scaling {
1818            Some(Gemma3RopeScalingConfig {
1819                rope_type: Gemma3ScaledRopeType::Linear,
1820                factor,
1821            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1822
1823            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1824        }
1825    }
1826
1827    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
1828        self.0.get_cos_sin()
1829    }
1830
1831    pub fn forward(
1832        &self,
1833        q: &Tensor,
1834        k: &Tensor,
1835        seqlen_offsets: &[usize],
1836    ) -> Result<(Tensor, Tensor)> {
1837        self.0.forward(q, k, seqlen_offsets)
1838    }
1839}
1840
1841#[derive(Debug, Clone)]
1842pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1843
1844#[derive(Debug, Clone, Deserialize, Serialize)]
1845#[serde(rename_all = "lowercase")]
1846pub enum Gemma3ScaledRopeType {
1847    #[serde(alias = "linear")]
1848    Linear,
1849}
1850
1851#[derive(Debug, Clone, Deserialize, Serialize)]
1852pub struct Gemma3RopeScalingConfig {
1853    factor: f64,
1854    rope_type: Gemma3ScaledRopeType,
1855}
1856
1857impl Gemma3RotaryEmbedding {
1858    fn new_linear(
1859        cfg: &Gemma3TextConfig,
1860        factor: f64,
1861        is_gpt_neox: bool,
1862        dtype: DType,
1863        dev: &Device,
1864    ) -> Result<Self> {
1865        let max_seq_len = cfg.max_position_embeddings;
1866        let dim = cfg.head_dim;
1867
1868        let inv_freq: Vec<_> = (0..dim)
1869            .step_by(2)
1870            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1871            .collect();
1872        let inv_freq_len = inv_freq.len();
1873        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1874        let inv_freq = (inv_freq / factor)?;
1875
1876        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1877            .to_dtype(DType::F32)?
1878            .reshape((max_seq_len, 1))?;
1879        let freqs = t.matmul(&inv_freq)?;
1880        let sin = freqs.sin()?.to_dtype(dtype)?;
1881        let cos = freqs.cos()?.to_dtype(dtype)?;
1882        Ok(Self(RotaryEmbedding {
1883            cos,
1884            sin,
1885            is_gpt_neox,
1886        }))
1887    }
1888
1889    pub fn new(
1890        is_gpt_neox: bool,
1891        dtype: DType,
1892        cfg: &Gemma3TextConfig,
1893        dev: &Device,
1894    ) -> Result<Self> {
1895        match &cfg.rope_scaling {
1896            Some(Gemma3RopeScalingConfig {
1897                rope_type: Gemma3ScaledRopeType::Linear,
1898                factor,
1899            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1900
1901            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1902        }
1903    }
1904
1905    fn new_linear_embedding_gemma(
1906        cfg: &EmbeddingGemmaConfig,
1907        factor: f64,
1908        is_gpt_neox: bool,
1909        dtype: DType,
1910        dev: &Device,
1911    ) -> Result<Self> {
1912        let max_seq_len = cfg.max_position_embeddings;
1913        let dim = cfg.head_dim;
1914
1915        let inv_freq: Vec<_> = (0..dim)
1916            .step_by(2)
1917            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1918            .collect();
1919        let inv_freq_len = inv_freq.len();
1920        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1921        let inv_freq = (inv_freq / factor)?;
1922
1923        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1924            .to_dtype(DType::F32)?
1925            .reshape((max_seq_len, 1))?;
1926        let freqs = t.matmul(&inv_freq)?;
1927        let sin = freqs.sin()?.to_dtype(dtype)?;
1928        let cos = freqs.cos()?.to_dtype(dtype)?;
1929        Ok(Self(RotaryEmbedding {
1930            cos,
1931            sin,
1932            is_gpt_neox,
1933        }))
1934    }
1935
1936    pub fn new_embedding_gemma(
1937        is_gpt_neox: bool,
1938        dtype: DType,
1939        cfg: &EmbeddingGemmaConfig,
1940        dev: &Device,
1941    ) -> Result<Self> {
1942        match &cfg.rope_scaling {
1943            Some(Gemma3RopeScalingConfig {
1944                rope_type: Gemma3ScaledRopeType::Linear,
1945                factor,
1946            }) => Self::new_linear_embedding_gemma(cfg, *factor, is_gpt_neox, dtype, dev),
1947
1948            _ => Self::new_linear_embedding_gemma(cfg, 1.0, is_gpt_neox, dtype, dev),
1949        }
1950    }
1951
1952    pub fn forward(
1953        &self,
1954        q: &Tensor,
1955        k: &Tensor,
1956        seqlen_offsets: &[usize],
1957    ) -> Result<(Tensor, Tensor)> {
1958        self.0.forward(q, k, seqlen_offsets)
1959    }
1960}
1961
1962pub struct DiaRotaryEmbedding {
1963    timescale: Tensor,
1964    dtype: DType,
1965}
1966
1967impl DiaRotaryEmbedding {
1968    pub fn new(
1969        min_timescale: f32,
1970        max_timescale: f32,
1971        head_dim: usize,
1972        device: &Device,
1973        dtype: DType,
1974    ) -> Result<Self> {
1975        assert_eq!(head_dim % 2, 0);
1976        let half_embedding_dim = head_dim / 2;
1977
1978        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
1979        let timescale = fraction
1981            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
1982            .collect::<Vec<_>>();
1983
1984        let timescale_len = timescale.len();
1985        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
1986
1987        Ok(Self { timescale, dtype })
1988    }
1989
1990    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
1991        let freqs = positions
1992            .unsqueeze(D::Minus1)?
1993            .unsqueeze(D::Minus1)?
1994            .broadcast_div(&self.timescale)?;
1995
1996        let sin = freqs.sin()?.to_dtype(self.dtype)?;
1997        let cos = freqs.cos()?.to_dtype(self.dtype)?;
1998
1999        let split = xs.chunk(2, D::Minus1)?;
2000        let first_half = &split[0];
2001        let second_half = &split[1];
2002
2003        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
2004        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
2005
2006        Tensor::cat(&[first_part, second_part], D::Minus1)
2007    }
2008}
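
// Sketch of the timescale construction above (illustrative values): with
// head_dim = 128, min_timescale = 1.0 and max_timescale = 10000.0, entry `i` of
// `timescale` is `10000^(2i / 128)`, so `forward` rotates each feature pair at
// frequency `1 / timescale[i]`, the classic sinusoidal RoPE layout with positions
// supplied explicitly as a tensor. Usage sketch (tensor names are assumptions):
//
//     let rope = DiaRotaryEmbedding::new(1.0, 10000.0, 128, &dev, DType::F32)?;
//     let ys = rope.forward(&xs, &positions)?;
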
2009#[derive(Debug, Clone)]
2010pub struct QLinear {
2011    inner: QMatMul,
2012    bias: Option<Tensor>,
2013    dtype: DType,
2014}
2015
2016impl QLinear {
2017    pub fn new<R: std::io::Read + std::io::Seek>(
2018        ct: &mut Content<'_, R>,
2019        name: &str,
2020        device: &Device,
2021    ) -> Result<Self> {
2022        let w = ct.tensor(&format!("{name}.weight"), device)?;
2023        let b = ct.tensor(&format!("{name}.bias"), device)?;
2024        let inner = QMatMul::from_qtensor(w)?;
2025        let bias = b.dequantize(device)?;
2026        Ok(Self {
2027            inner,
2028            bias: Some(bias),
2029            dtype: DType::F32,
2030        })
2031    }
2032
2033    pub fn from_linear(linear: Linear) -> Self {
2034        Self {
2035            inner: QMatMul::Tensor(linear.weight().clone()),
2036            bias: linear.bias().cloned(),
2037            dtype: linear.weight().dtype(),
2038        }
2039    }
2040
2041    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
2042        let dtype = w.dtype();
2043        Self {
2044            inner: QMatMul::Tensor(w),
2045            bias: b,
2046            dtype,
2047        }
2048    }
2049
2050    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
2051        if let Some(ref b) = b {
2052            assert_eq!(b.dtype(), DType::F32);
2053        }
2054        Self {
2055            inner: QMatMul::QTensor(Arc::new(w)),
2056            bias: b,
2057            dtype: DType::F32,
2058        }
2059    }
2060
2061    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
2062        Self {
2063            inner,
2064            bias: old.bias.clone(),
2065            dtype: old.dtype,
2066        }
2067    }
2068
2069    pub fn inner(&mut self) -> &mut QMatMul {
2070        &mut self.inner
2071    }
2072
2073    pub fn inner_ref(&self) -> &QMatMul {
2074        &self.inner
2075    }
2076
2077    pub fn is_quant(&self) -> bool {
2078        matches!(self.inner, QMatMul::QTensor(_))
2079    }
2080
2081    pub fn bias(&self) -> Option<&Tensor> {
2082        self.bias.as_ref()
2083    }
2084
2085    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
2086        self.bias.as_mut()
2087    }
2088}
2089
2090impl Module for QLinear {
2091    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2092        let xs = if self.is_quant() {
2093            xs.to_dtype(DType::F32)?
2094        } else {
2095            xs.clone()
2096        };
2097        if let Some(bias) = &self.bias {
2098            self.inner
2099                .forward(&xs)?
2100                .broadcast_add(bias)?
2101                .to_dtype(self.dtype)
2102        } else {
2103            self.inner.forward(&xs)?.to_dtype(self.dtype)
2104        }
2105    }
2106}
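
// Illustrative usage sketch (the `w`/`b` tensors are assumptions for demonstration):
// `QLinear` returns output in the dtype of the original weights, temporarily
// running in f32 when the inner matmul is quantized.
//
//     let lin = QLinear::from_parts(w, Some(b));
//     let ys = lin.forward(&xs)?;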
2107
2108#[derive(Debug, Clone)]
2109pub struct RotaryEmbedding {
2110    cos: Tensor,
2111    sin: Tensor,
2112    is_gpt_neox: bool,
2113}
2114
2115impl RotaryEmbedding {
2116    pub fn new(
2117        base: f32,
2118        head_dim: usize,
2119        max_position_embeddings: usize,
2120        device: &Device,
2121        is_gpt_neox: bool,
2122        dtype: DType,
2123    ) -> Result<Self> {
2124        let inv_freq: Vec<_> = (0..head_dim)
2125            .step_by(2)
2126            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
2127            .collect();
2128        let inv_freq_len = inv_freq.len();
2129        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2130        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2131            .to_dtype(DType::F32)?
2132            .reshape((max_position_embeddings, 1))?;
2133        let freqs = t.matmul(&inv_freq)?;
2134        let sin = freqs.sin()?.to_dtype(dtype)?;
2135        let cos = freqs.cos()?.to_dtype(dtype)?;
2136
2137        Ok(Self {
2138            cos,
2139            sin,
2140            is_gpt_neox,
2141        })
2142    }
2143
2144    pub fn get_cos_sin(&self) -> Result<(Tensor, Tensor)> {
2145        Ok((self.cos.clone(), self.sin.clone()))
2146    }
2147
2148    pub fn new_partial(
2149        base: f32,
2150        rot_dim: usize,
2151        max_position_embeddings: usize,
2152        device: &Device,
2153        is_gpt_neox: bool,
2154        dtype: DType,
2155    ) -> Result<Self> {
2156        let inv_freq: Vec<_> = (0..rot_dim)
2157            .step_by(2)
2158            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
2159            .collect();
2160        let inv_freq_len = inv_freq.len();
2161        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
2162        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
2163            .to_dtype(DType::F32)?
2164            .reshape((max_position_embeddings, 1))?;
2165        let freqs = t.matmul(&inv_freq)?;
2166        let sin = freqs.sin()?.to_dtype(dtype)?;
2167        let cos = freqs.cos()?.to_dtype(dtype)?;
2168
2169        Ok(Self {
2170            cos,
2171            sin,
2172            is_gpt_neox,
2173        })
2174    }
2175
2176    pub fn forward(
2177        &self,
2178        q: &Tensor,
2179        k: &Tensor,
2180        seqlen_offsets: &[usize],
2181    ) -> Result<(Tensor, Tensor)> {
2182        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
2183        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
2184
2185        let rope = if self.is_gpt_neox {
2186            candle_nn::rotary_emb::rope
2187        } else {
2188            candle_nn::rotary_emb::rope_i
2189        };
2190
2191        if cfg!(feature = "cuda") && qh == kh {
2192            let (cos, sin) = if seqlen_offsets.len() == 1 {
2193                (
2194                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
2195                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
2196                )
2197            } else {
2198                let mut cos_s = Vec::new();
2199                let mut sin_s = Vec::new();
2200                for offset in seqlen_offsets {
2201                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
2202                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
2203                }
2204                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
2205            };
2206
2207            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
2208            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
2209            mistralrs_quant::rotary::apply_rotary_inplace(
2210                &q_embed,
2211                &k_embed,
2212                &cos,
2213                &sin,
2214                self.is_gpt_neox,
2215            )?;
2216            let mut q = q_embed
2217                .reshape((b_sz, seq_len, qh, n_embd))?
2218                .transpose(1, 2)?;
2219            let mut k = k_embed
2220                .reshape((b_sz, seq_len, kh, n_embd))?
2221                .transpose(1, 2)?;
2222            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
2223                q = q.contiguous()?;
2224                k = k.contiguous()?;
2225            }
2226            Ok((q, k))
2227        } else if seqlen_offsets.len() == 1 {
2228            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
2229            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
2230            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
2231            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
2232            Ok((q_embed, k_embed))
2233        } else {
2234            let mut q_embeds = Vec::new();
2235            let mut k_embeds = Vec::new();
2236            for (i, offset) in seqlen_offsets.iter().enumerate() {
2237                let cos = self.cos.narrow(0, *offset, seq_len)?;
2238                let sin = self.sin.narrow(0, *offset, seq_len)?;
2239                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2240                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
2241                q_embeds.push(q_embed);
2242                k_embeds.push(k_embed);
2243            }
2244            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
2245        }
2246    }
2247}
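
// Illustrative usage sketch (hyperparameters are assumptions, not tied to any
// particular model): a GPT-NeoX style rotary embedding over 64-dim heads with a
// 4096-position table, applied to q/k of shape (batch, heads, seq, head_dim) with
// one KV-cache offset per sequence:
//
//     let rope = RotaryEmbedding::new(10_000.0, 64, 4096, &dev, true, DType::F32)?;
//     let (q, k) = rope.forward(&q, &k, &[seqlen_offset])?;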
2248
2249#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
2250#[serde(rename_all = "lowercase")]
2251pub enum Activation {
2252    #[default]
2253    #[serde(alias = "gelu")]
2254    Gelu,
2255    #[serde(alias = "gelu_new")]
2256    NewGelu,
2257    Relu,
2258    Relu2,
2259    Relu6,
2260    Silu,
2261    Sigmoid,
2262    HardSigmoid,
2263    Swiglu,
2264    Swish,
2265    HardSwish,
2266    Elu(f64),
2267    LeakyRelu(f64),
2268    #[serde(alias = "gelu_pytorch_tanh")]
2269    GeluPytorchTanh,
2270    QuickGelu,
2271}
2272
2273impl Module for Activation {
2274    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2275        match self {
2276            Self::Gelu => xs.gelu_erf(),
2277            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
2278            Self::NewGelu => xs.gelu(),
2279            Self::Relu => xs.relu(),
2280            Self::Relu2 => xs.relu()?.sqr(),
2281            Self::Relu6 => xs.clamp(0f32, 6f32),
2282            Self::Silu => xs.silu(),
2283            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
2284            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
2285            Self::Swiglu => candle_nn::ops::swiglu(xs),
2286            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
2287            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
2288            &Self::Elu(alpha) => xs.elu(alpha),
2289            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
2290            Self::GeluPytorchTanh => xs.gelu(),
2291            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
2292        }
2293    }
2294}
2295
2296impl TryInto<candle_nn::Activation> for Activation {
2297    type Error = candle_core::Error;
2298
2299    fn try_into(self) -> Result<candle_nn::Activation> {
2300        match self {
2301            Self::Gelu => Ok(candle_nn::Activation::Gelu),
2302            Self::Relu => Ok(candle_nn::Activation::Relu),
2303            Self::Silu => Ok(candle_nn::Activation::Silu),
2304            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
2305            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
2306            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
2307            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
2308            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
2309            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
2310            Self::Swish => Ok(candle_nn::Activation::Swish),
2311            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
2312            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
2313            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
2314            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
2315            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
2316        }
2317    }
2318}
2319
2320#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2321pub struct Conv3dConfig {
2322    pub padding: usize,
2323    pub stride: usize,
2324    pub dilation: usize,
2325    pub groups: usize,
2326}
2327
2328impl Default for Conv3dConfig {
2329    fn default() -> Self {
2330        Self {
2331            padding: 0,
2332            stride: 1,
2333            dilation: 1,
2334            groups: 1,
2335        }
2336    }
2337}
2338
2339pub struct Conv3dNoBias {
2340    conv2d_1: Conv2d,
2341    conv2d_2: Conv2d,
2342}
2343
2344impl Conv3dNoBias {
2345    pub fn new(
2346        in_channels: usize,
2347        out_channels: usize,
2348        kernel_sizes: [usize; 3],
2349        cfg: Conv3dConfig,
2350        vb: ShardedVarBuilder,
2351    ) -> Result<Self> {
2352        let ws = vb.get(
2353            (
2354                out_channels,
2355                in_channels / cfg.groups,
2356                kernel_sizes[0],
2357                kernel_sizes[1],
2358                kernel_sizes[2],
2359            ),
2360            "weight",
2361        )?;
2362
2363        // Split on temporal dimension
2364        // https://github.com/pytorch/pytorch/issues/139066
2365
2366        let w1 = ws.i((.., .., 0, .., ..))?;
2367        let w2 = ws.i((.., .., 1, .., ..))?;
2368
2369        let cfg = Conv2dConfig {
2370            padding: cfg.padding,
2371            stride: cfg.stride,
2372            dilation: cfg.dilation,
2373            groups: cfg.groups,
2374            cudnn_fwd_algo: None,
2375        };
2376
2377        Ok(Self {
2378            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
2379            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
2380        })
2381    }
2382
2383    pub fn weight(&self) -> Result<Tensor> {
2384        let w1 = self.conv2d_1.weight().clone().unsqueeze(2)?;
2385        let w2 = self.conv2d_2.weight().clone().unsqueeze(2)?;
2386        Tensor::cat(&[w1, w2], 2)
2387    }
2388}
2389
2390impl Module for Conv3dNoBias {
2391    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2392        let xs1 = xs.i((.., .., 0, .., ..))?;
2393        let xs2 = xs.i((.., .., 1, .., ..))?;
2394
2395        (Convolution.forward_2d(&self.conv2d_1, &xs1)?
2396            + Convolution.forward_2d(&self.conv2d_2, &xs2)?)?
2397        .unsqueeze(2)
2398    }
2399}
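
// The split above relies on the temporal kernel size (and the temporal extent of
// the input) being exactly 2: a 3D convolution with a (2, kh, kw) kernel over two
// frames equals the sum of two 2D convolutions, one per frame, which is what
// `forward` computes before restoring the collapsed temporal dimension with
// `unsqueeze(2)`.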
2400
2401pub trait TensorInfExtend {
2402    fn is_inf(&self) -> Result<Self>
2403    where
2404        Self: Sized;
2405    fn any(&self) -> Result<bool>;
2406}
2407
2408impl TensorInfExtend for Tensor {
2409    fn is_inf(&self) -> Result<Self> {
2410        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
2411    }
2412
2413    fn any(&self) -> Result<bool> {
2414        let sum = self.sum_all()?;
2415        match self.dtype() {
2416            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
2417            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
2418            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
2419            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
2420            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
2421            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
2422            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
2423            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
2424            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
2425            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
2426            DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
2427                candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with .any")
2428            }
2429        }
2430    }
2431}
2432
2433pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
2434    let mut max = match xs.dtype() {
2435        DType::U8 => u8::MAX as f32 - 1000.,
2436        DType::U32 => u32::MAX as f32 - 1000.,
2437        DType::I16 => i16::MAX as f32 - 1000.,
2438        DType::I32 => i32::MAX as f32 - 1000.,
2439        DType::I64 => i64::MAX as f32 - 1000.,
2440        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
2441        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
2442        DType::F32 => f32::MAX - 1000.,
2443        DType::F64 => f64::MAX as f32 - 1000.,
2444        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
2445        DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
2446            candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors are not supported with `clamp_for_f16`")
2447        }
2448    };
2449    if xs.is_inf()?.any()? {
2450        max -= 1000.;
2451    }
2452    xs.clamp(-max, max)
2453}
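
// Illustrative usage sketch: clamp activations before a narrowing cast that could
// otherwise overflow (the tensor name is an assumption for demonstration).
//
//     let hidden = clamp_for_f16(&hidden)?.to_dtype(DType::F16)?;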
2454
2455pub struct FloatInfo {
2456    /// Minimum representable value.
2457    pub min: f64,
2458    /// Maximum representable value.
2459    pub max: f64,
2460    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
2461    pub eps: f64,
2462    pub dtype: DType,
2463}
2464
2465pub trait GetFloatInfo {
2466    fn finfo(&self) -> Result<FloatInfo>;
2467}
2468
2469impl GetFloatInfo for DType {
2470    fn finfo(&self) -> Result<FloatInfo> {
2471        let finfo = match self {
2472            Self::BF16 => FloatInfo {
2473                min: bf16::MIN.to_f64(),
2474                max: bf16::MAX.to_f64(),
2475                eps: bf16::EPSILON.to_f64(),
2476                dtype: DType::BF16,
2477            },
2478            Self::F16 => FloatInfo {
2479                min: f16::MIN.to_f64(),
2480                max: f16::MAX.to_f64(),
2481                eps: f16::EPSILON.to_f64(),
2482                dtype: DType::F16,
2483            },
2484            Self::F32 => FloatInfo {
2485                min: f32::MIN as f64,
2486                max: f32::MAX as f64,
2487                eps: f32::EPSILON as f64,
2488                dtype: DType::F32,
2489            },
2490            Self::F64 => FloatInfo {
2491                min: f64::MIN,
2492                max: f64::MAX,
2493                eps: f64::EPSILON,
2494                dtype: DType::F64,
2495            },
2496            Self::F8E4M3 => FloatInfo {
2497                min: F8E4M3::MIN.to_f64(),
2498                max: F8E4M3::MAX.to_f64(),
2499                eps: F8E4M3::EPSILON.to_f64(),
2500                dtype: DType::F8E4M3,
2501            },
2502            other => {
2503                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
2504            }
2505        };
2506        Ok(finfo)
2507    }
2508}
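
// Illustrative usage sketch: fetch the most negative representable value of the
// working dtype, e.g. to fill masked attention positions.
//
//     let neg_inf = xs.dtype().finfo()?.min;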
2509
2510#[derive(Clone)]
2511pub struct Mlp {
2512    pub gate: Arc<dyn QuantMethod>,
2513    pub up: Arc<dyn QuantMethod>,
2514    pub down: Arc<dyn QuantMethod>,
2515    act: Activation,
2516    params: Vec<usize>,
2517}
2518
2519impl Mlp {
2520    pub fn new(
2521        vb: ShardedVarBuilder,
2522        hidden_size: usize,
2523        intermediate_size: usize,
2524        quantization_config: &Option<QuantizedConfig>,
2525        hidden_act: Activation,
2526        comm: &Arc<mistralrs_quant::Comm>,
2527    ) -> Result<Self> {
2528        Ok(Self {
2529            gate: ColumnParallelLayer::new(
2530                hidden_size,
2531                intermediate_size,
2532                quantization_config,
2533                false,
2534                comm,
2535                vb.pp("gate_proj"),
2536            )?,
2537            up: ColumnParallelLayer::new(
2538                hidden_size,
2539                intermediate_size,
2540                quantization_config,
2541                false,
2542                comm,
2543                vb.pp("up_proj"),
2544            )?,
2545            down: RowParallelLayer::new(
2546                intermediate_size,
2547                hidden_size,
2548                quantization_config,
2549                false,
2550                comm,
2551                vb.pp("down_proj"),
2552            )?,
2553            act: hidden_act,
2554            params: vec![hidden_size, intermediate_size],
2555        })
2556    }
2557
2558    pub fn new_merged(
2559        vb: ShardedVarBuilder,
2560        hidden_size: usize,
2561        intermediate_size: usize,
2562        chunks: usize,
2563        quantization_config: &Option<QuantizedConfig>,
2564        hidden_act: Activation,
2565        comm: &Arc<mistralrs_quant::Comm>,
2566    ) -> Result<Self> {
2567        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
2568        let gate_up_projs = ColumnParallelLayer::new_merged(
2569            hidden_size,
2570            intermediate_size * 2,
2571            2,
2572            quantization_config,
2573            false,
2574            comm,
2575            vb.pp("gate_up_proj"),
2576        )?;
2577
2578        Ok(Self {
2579            gate: gate_up_projs[0].to_owned(),
2580            up: gate_up_projs[1].to_owned(),
2581            down: RowParallelLayer::new(
2582                intermediate_size,
2583                hidden_size,
2584                quantization_config,
2585                false,
2586                comm,
2587                vb.pp("down_proj"),
2588            )?,
2589            act: hidden_act,
2590            params: vec![hidden_size, intermediate_size],
2591        })
2592    }
2593
2594    pub fn replicate(
2595        params: &[usize],
2596        vb: ShardedVarBuilder,
2597        act: Activation,
2598        comm: &Arc<mistralrs_quant::Comm>,
2599    ) -> Result<Self> {
2600        Self::new(vb, params[0], params[1], &None, act, comm)
2601    }
2602
2603    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2604        let original_dtype = xs.dtype();
2605        let mut xs = xs.clone();
2606        if let Some(t) = self.gate.quantized_act_type() {
2607            xs = xs.to_dtype(t)?;
2608        }
2609        let lhs = self.gate.forward(&xs)?;
2610        let rhs = self.up.forward(&xs)?;
2611        let mut res = self
2612            .down
2613            .forward(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?)?;
2614        if self.gate.quantized_act_type().is_some() {
2615            res = res.to_dtype(original_dtype)?;
2616        }
2617        Ok(res)
2618    }
2619}
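
// The forward pass above is the standard gated MLP, `down(act(gate(x)) * up(x))`,
// with `mul_and_act` fusing the elementwise product and activation. Illustrative
// construction sketch (the sizes and the `comm` handle are assumptions):
//
//     let mlp = Mlp::new(vb.pp("mlp"), 4096, 11008, &None, Activation::Silu, &comm)?;
//     let ys = mlp.forward(&xs)?;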
2620
2621impl AnyMoeTrainableLayer for Mlp {}
2622
2623impl MlpLayer for Mlp {
2624    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2625        let original_dtype = xs.dtype();
2626        let mut xs = xs.clone();
2627        if let Some(t) = self.gate.quantized_act_type() {
2628            xs = xs.to_dtype(t)?;
2629        }
2630        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2631        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
2632        let mut res =
2633            MatMul.qmethod_matmul(&crate::ops::mul_and_act(&lhs, &rhs, self.act)?, &*self.down)?;
2634        if self.gate.quantized_act_type().is_some() {
2635            res = res.to_dtype(original_dtype)?;
2636        }
2637        Ok(res)
2638    }
2639    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2640        vec![&mut self.gate, &mut self.up, &mut self.down]
2641    }
2642    fn clone(&self) -> Box<dyn MlpLayer> {
2643        Box::new(Clone::clone(self))
2644    }
2645    fn get_params(&self) -> &[usize] {
2646        &self.params
2647    }
2648    fn hidden_act(&self) -> Activation {
2649        self.act
2650    }
2651    // gate, up, down
2652    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2653        let gate = if let Some(ref delta) = deltas[0] {
2654            self.gate.add_delta_w(delta)?
2655        } else {
2656            self.gate.clone()
2657        };
2658        let up = if let Some(ref delta) = deltas[1] {
2659            self.up.add_delta_w(delta)?
2660        } else {
2661            self.up.clone()
2662        };
2663        let down = if let Some(ref delta) = deltas[2] {
2664            self.down.add_delta_w(delta)?
2665        } else {
2666            self.down.clone()
2667        };
2668
2669        Ok(Box::new(Self {
2670            gate,
2671            up,
2672            down,
2673            act: self.act,
2674            params: self.params.clone(),
2675        }))
2676    }
2677
2678    fn dtype_device(&self) -> (DType, Device) {
2679        self.gate.dtype_and_device()
2680    }
2681}
2682
2683pub struct AvgPool2d {
2684    kernel_size: usize,
2685    stride: usize,
2686}
2687
2688impl AvgPool2d {
2689    pub fn new(kernel_size: usize, stride: usize) -> Self {
2690        Self {
2691            kernel_size,
2692            stride,
2693        }
2694    }
2695
2696    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2697        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2698    }
2699}
2700
2701/// Applies 2D reflection padding to a tensor of shape (N, C, H, W).
2702///
2703/// The `padding` argument is a 4-tuple (pad_left, pad_right, pad_top, pad_bottom).
2704/// For left padding, it reflects the values from column 1 up to pad_left (in reverse order);
2705/// for right padding, it reflects from the second-to-last column backwards, and similarly for
2706/// vertical (height) padding.
2707pub struct ReflectionPad2d {
2708    padding: (usize, usize, usize, usize),
2709}
2710
2711impl ReflectionPad2d {
2712    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2713        Self { padding }
2714    }
2715}
2716
2717impl Module for ReflectionPad2d {
2718    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2719        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2720
2721        let (_n, _c, h, w) = xs.dims4()?;
2722
2723        // --- Horizontal Padding (along width, axis = 3) ---
2724        // For left padding, we reflect columns 1..=pad_left (in reverse order).
2725        let left_pad = if pad_left > 0 {
2726            // Create indices: [pad_left, pad_left-1, ..., 1]
2727            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2728            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2729        } else {
2730            None
2731        };
2732
2733        // For right padding, we reflect from the right side (excluding the last column).
2734        let right_pad = if pad_right > 0 {
2735            // Generate indices [w-2, w-3, ..., w-1-pad_right] (reflect from the second-to-last column)
2736            let start = w as i64 - 2;
2737            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2738            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2739        } else {
2740            None
2741        };
2742
2743        // Concatenate horizontally (along width, dim=3)
2744        let x_padded_width = match (left_pad, right_pad) {
2745            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2746            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2747            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2748            (None, None) => xs.clone(),
2749        };
2750
2751        // --- Vertical Padding (along height, axis = 2) ---
2752        // For top padding, reflect rows 1..=pad_top (in reverse order)
2753        let top_pad = if pad_top > 0 {
2754            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2755            Some(x_padded_width.index_select(&Tensor::new(indices, x_padded_width.device())?, 2)?)
2756        } else {
2757            None
2758        };
2759
2760        // For bottom padding, reflect from the bottom (excluding the last row)
2761        let bottom_pad = if pad_bottom > 0 {
2762            let start = h as i64 - 2;
2763            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2764            Some(x_padded_width.index_select(&Tensor::new(indices, x_padded_width.device())?, 2)?)
2765        } else {
2766            None
2767        };
2768
2769        // Concatenate vertically (along height, dim=2)
2770        let x_padded = match (top_pad, bottom_pad) {
2771            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2772            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2773            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2774            (None, None) => x_padded_width,
2775        };
2776
2777        Ok(x_padded)
2778    }
2779}
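
// Worked example of the reflection rule above on a single row `[a, b, c]` (W = 3)
// with padding (1, 1, 0, 0): the left pad takes column 1 and the right pad takes
// column W - 2, giving `[b, a, b, c, b]`; the edge values themselves are never
// duplicated, matching PyTorch's ReflectionPad2d.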
2780
2781pub struct ScaledEmbedding {
2782    scale: f64,
2783    pub embedding: Tensor,
2784}
2785
2786impl ScaledEmbedding {
2787    pub fn new(scale: f64, embedding: Embedding) -> Self {
2788        Self {
2789            scale,
2790            embedding: embedding.embeddings().clone(),
2791        }
2792    }
2793
2794    pub fn embeddings(&self) -> &Tensor {
2795        &self.embedding
2796    }
2797}
2798
2799impl Module for ScaledEmbedding {
2800    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2801        let embedding = Embedding::new(self.embedding.clone(), self.embedding.dim(D::Minus1)?);
2802        xs.apply(&embedding)? * self.scale
2803    }
2804}
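
// Illustrative usage sketch: Gemma-style models scale token embeddings by
// sqrt(hidden_size); wrapping the loaded `Embedding` keeps that scaling out of the
// model's forward pass (the size here is an assumption).
//
//     let embed = ScaledEmbedding::new((2048f64).sqrt(), embedding);
//     let hidden = embed.forward(&input_ids)?;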