// mistralrs_core/layers.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6    quantized::{QMatMul, QTensor},
7    Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10    Conv2d, Conv2dConfig, Embedding, GroupNorm, LayerNorm, LayerNormConfig, Linear, Module,
11};
12use float8::F8E4M3;
13use half::{bf16, f16};
14use mistralrs_quant::{
15    AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
16    ShardedVarBuilder,
17};
18use serde::{Deserialize, Serialize};
19
20pub use crate::attention::Sdpa;
21pub use crate::layers_masker::CausalMasker;
22pub use crate::layers_utils::repeat_kv;
23use crate::{
24    amoe::{AnyMoeTrainableLayer, MlpLayer},
25    gguf::Content,
26    models::llama,
27    ops::SplitOp,
28    vision_models::{
29        gemma3::config::Gemma3TextConfig,
30        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
31        phi4::Phi4MMConfig,
32    },
33};
34
35pub use mistralrs_quant::MatMul;
36
37pub fn embedding(
38    in_size: usize,
39    out_size: usize,
40    vb: ShardedVarBuilder,
41    config: &Option<QuantizedConfig>,
42) -> Result<Embedding> {
    // AFQ quantization also covers the embedding table, so load it through AfqLayer and dequantize.
44    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
45        let afq_layer =
46            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
47        afq_layer.dequantize_w()?
48    } else {
49        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
50    };
51    Ok(Embedding::new(embeddings, out_size))
52}
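
// Usage sketch: a (hypothetical) 32k x 4096 token-embedding table. The var-builder prefix
// and sizes are illustrative; callers pass the model's own quantization config.
#[allow(dead_code)]
fn _example_embedding(
    vb: ShardedVarBuilder,
    quant_cfg: &Option<QuantizedConfig>,
) -> Result<Embedding> {
    embedding(32_000, 4096, vb.pp("model.embed_tokens"), quant_cfg)
}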
53
54pub fn layer_norm<C: Into<LayerNormConfig>>(
55    size: usize,
56    config: C,
57    vb: ShardedVarBuilder,
58) -> Result<LayerNorm> {
59    let config = config.into();
60    let weight = vb.get(size, "weight")?;
61    if config.affine {
62        let bias = vb.get(size, "bias")?;
63        Ok(LayerNorm::new(weight, bias, config.eps))
64    } else {
65        Ok(LayerNorm::new_no_bias(weight, config.eps))
66    }
67}
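
// Sketch: a standard affine LayerNorm over the hidden dimension; the epsilon and size are
// placeholder values, not taken from any particular model.
#[allow(dead_code)]
fn _example_layer_norm(vb: ShardedVarBuilder) -> Result<LayerNorm> {
    let cfg = LayerNormConfig {
        eps: 1e-5,
        ..Default::default()
    };
    layer_norm(4096, cfg, vb)
}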
68
69pub fn group_norm(
70    num_groups: usize,
71    num_channels: usize,
72    eps: f64,
73    vb: ShardedVarBuilder,
74) -> Result<GroupNorm> {
75    let weight = vb.get(num_channels, "weight")?;
76    let bias = vb.get(num_channels, "bias")?;
77    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
78}
79
80pub fn conv2d(
81    in_channels: usize,
82    out_channels: usize,
83    kernel_size: usize,
84    cfg: Conv2dConfig,
85    vb: ShardedVarBuilder,
86) -> Result<Conv2d> {
87    let ws = vb.get(
88        (
89            out_channels,
90            in_channels / cfg.groups,
91            kernel_size,
92            kernel_size,
93        ),
94        "weight",
95    )?;
96    let bs = vb.get(out_channels, "bias")?;
97    Ok(Conv2d::new(ws, Some(bs), cfg))
98}
99
100pub fn conv2d_no_bias(
101    in_channels: usize,
102    out_channels: usize,
103    kernel_size: usize,
104    cfg: Conv2dConfig,
105    vb: ShardedVarBuilder,
106) -> Result<Conv2d> {
107    let ws = vb.get(
108        (
109            out_channels,
110            in_channels / cfg.groups,
111            kernel_size,
112            kernel_size,
113        ),
114        "weight",
115    )?;
116    Ok(Conv2d::new(ws, None, cfg))
117}
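
// Sketch: a ViT-style patch embedding mapping 3-channel images to 1024 channels, with the
// stride equal to the (14x14) patch size. All values and the vb prefix are illustrative.
#[allow(dead_code)]
fn _example_patch_embed(vb: ShardedVarBuilder) -> Result<Conv2d> {
    let cfg = Conv2dConfig {
        stride: 14,
        ..Default::default()
    };
    conv2d_no_bias(3, 1024, 14, cfg, vb.pp("patch_embedding"))
}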
118
119pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
120    let ws = vb.get((out_dim, in_dim), "weight")?;
121    let bs = vb.get(out_dim, "bias")?;
122    Ok(Linear::new(ws, Some(bs)))
123}
124
125pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
126    let ws = vb.get((out_dim, in_dim), "weight")?;
127    Ok(Linear::new(ws, None))
128}
129
130pub fn linear_b(
131    in_dim: usize,
132    out_dim: usize,
133    bias: bool,
134    vb: ShardedVarBuilder,
135) -> Result<Linear> {
136    if bias {
137        linear(in_dim, out_dim, vb)
138    } else {
139        linear_no_bias(in_dim, out_dim, vb)
140    }
141}
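
// Sketch: `linear_b` dispatches on `bias`, so projections that differ only in whether the
// checkpoint ships a bias can share one call site. Sizes and the vb prefix are illustrative.
#[allow(dead_code)]
fn _example_proj(vb: ShardedVarBuilder, bias: bool) -> Result<Linear> {
    linear_b(4096, 4096, bias, vb.pp("q_proj"))
}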
142
143#[derive(Debug, Clone)]
144pub struct RmsNorm {
145    eps: f64,
146    weight: Tensor,
147}
148
149impl RmsNorm {
150    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
151        let w = vb.get(size, "weight")?;
152        Ok(Self { eps, weight: w })
153    }
154
155    /// Gemma uses weight + 1.0
156    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
157        let w = vb.get(size, "weight")?;
158        let w = (w + 1.0)?;
159        Ok(Self { eps, weight: w })
160    }
161
162    /// Gemma uses weight + 1.0. Undo for UQFF generation.
163    pub fn undo_gemma(&self) -> Result<Self> {
164        Ok(Self {
165            eps: self.eps,
166            weight: (&self.weight - 1.0)?,
167        })
168    }
169
170    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
171        Ok(Self { eps, weight: w })
172    }
173
174    pub fn weight(&self) -> &Tensor {
175        &self.weight
176    }
177}
178
179impl Module for RmsNorm {
180    fn forward(&self, x: &Tensor) -> Result<Tensor> {
181        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
182    }
183}
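
// Sketch: build an RmsNorm and run it on a dummy activation; `new_gemma` would be used
// instead for checkpoints that store `weight - 1`. Shapes and the vb prefix are illustrative.
#[allow(dead_code)]
fn _example_rms_norm(vb: ShardedVarBuilder) -> Result<Tensor> {
    let norm = RmsNorm::new(4096, 1e-6, vb.pp("input_layernorm"))?;
    let w = norm.weight();
    let xs = Tensor::zeros((1, 7, 4096), w.dtype(), w.device())?;
    norm.forward(&xs)
}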
184
185#[derive(Debug, Clone)]
186pub struct F32RmsNorm {
187    w: Tensor,
188    eps: f64,
189}
190
191impl F32RmsNorm {
192    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
193        Ok(Self {
194            w: vb.get((size,), "weight")?,
195            eps,
196        })
197    }
198
199    pub fn weight(&self) -> &Tensor {
200        &self.w
201    }
202}
203
204impl Module for F32RmsNorm {
205    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
206        let initial_type = xs.dtype();
207        let mut xs = xs.to_dtype(DType::F32)?;
208        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
209        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
210        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
211    }
212}
213
214#[derive(Debug, Clone)]
215pub struct QRmsNorm {
216    eps: f64,
217    weight: Tensor,
218}
219
220impl QRmsNorm {
221    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
222        let scale = scale.dequantize(&scale.device())?;
223        Ok(Self {
224            eps: eps as f64,
225            weight: scale,
226        })
227    }
228
229    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
230        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
231    }
232}
233
234/// RoPE supporting LongRope
235#[derive(Debug, Clone)]
236pub struct PhiRotaryEmbedding {
237    short_sin: Tensor,
238    short_cos: Tensor,
239    long_cos: Option<Tensor>,
240    long_sin: Option<Tensor>,
241    original_max_position_embeddings: usize,
242}
243
244#[derive(Debug, Clone, Deserialize, Serialize)]
245#[serde(rename_all = "lowercase")]
246pub enum ScaledRopeType {
247    #[serde(alias = "su")]
248    #[serde(alias = "longrope")]
249    Su,
250    #[serde(alias = "yarn")]
251    Yarn,
252    #[serde(alias = "dynamic")]
253    Dynamic,
254    #[serde(alias = "linear")]
255    Linear,
256}
257
258impl FromStr for ScaledRopeType {
259    type Err = candle_core::Error;
260    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
261        match s {
262            "su" | "longrope" => Ok(Self::Su),
263            "yarn" => Ok(Self::Yarn),
264            "linear" => Ok(Self::Linear),
265            "dynamic" => Ok(Self::Dynamic),
266            _ => Err(candle_core::Error::Msg(
                "Unknown scaled RoPE type; expected one of `su`/`longrope`, `yarn`, `linear`, or `dynamic`.".to_string(),
268            )),
269        }
270    }
271}
272
273#[derive(Debug, Clone, Deserialize, Serialize)]
274#[serde(untagged)]
275pub enum PhiRopeScalingConfig {
276    Classic {
277        short_factor: Vec<f64>,
278        long_factor: Vec<f64>,
279        #[serde(rename = "type")]
280        scaling_type: ScaledRopeType,
281    },
282    Scaled {
283        short_factor: Vec<f64>,
284        long_factor: Vec<f64>,
285        #[serde(rename = "type")]
286        scaling_type: ScaledRopeType,
287        long_mscale: f64,
288        short_mscale: f64,
289    },
290}
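
// Sketch: deserializing a "classic" Phi long-rope scaling section; the `type` string goes
// through `ScaledRopeType`'s serde aliases (`su`/`longrope`, `yarn`, `linear`, `dynamic`).
// `serde_json` is assumed to be available here, as it is elsewhere in the crate.
#[allow(dead_code)]
fn _example_phi_rope_scaling() -> std::result::Result<PhiRopeScalingConfig, serde_json::Error> {
    serde_json::from_str(r#"{"type":"longrope","short_factor":[1.0,1.0],"long_factor":[2.0,2.0]}"#)
}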
291
292pub struct PhiRopeConfig {
293    pub rope_scaling: Option<PhiRopeScalingConfig>,
294    pub max_position_embeddings: usize,
295    pub original_max_position_embeddings: usize,
296    pub rope_theta: f64,
297    pub head_dim: usize,
298    pub partial_rotary_factor: Option<f64>,
299}
300
301impl PhiRotaryEmbedding {
302    fn new_classic_scaled(
303        short_factor: &[f64],
304        long_factor: &[f64],
305        scaling_type: &ScaledRopeType,
306        cfg: &PhiRopeConfig,
307        dtype: DType,
308        dev: &Device,
309    ) -> Result<Self> {
310        let max_seq_len = cfg.max_position_embeddings;
311        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
312
313        // Calculate scale
314        let scale =
315            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
316        let scaling_factor = if scale <= 1.0 {
317            1.0
318        } else {
319            match scaling_type {
320                ScaledRopeType::Su => {
321                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
322                }
323                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
324                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
325            }
326        };
327
328        // Calculate inv freqs for short, long
329        let inv_freq_long = (0..dim)
330            .step_by(2)
331            .enumerate()
332            .map(|(k, i)| {
333                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
334            })
335            .collect::<Vec<_>>();
336        let inv_freq_short = (0..dim)
337            .step_by(2)
338            .enumerate()
339            .map(|(k, i)| {
340                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
341            })
342            .collect::<Vec<_>>();
343        let inv_freq_len = inv_freq_long.len();
344
345        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
346            .to_dtype(DType::F32)?
347            .reshape((max_seq_len, 1))?;
348
349        // Calculate sin,cos for long
350        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
351        let freqs_long = t.matmul(&inv_freq_long)?;
352        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
353        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
354
355        // Calculate sin,cos for short
356        let inv_freq_short =
357            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
358        let freqs_short = t.matmul(&inv_freq_short)?;
359        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
360        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
361
362        Ok(Self {
363            short_cos,
364            short_sin,
365            long_cos: Some(long_cos),
366            long_sin: Some(long_sin),
367            original_max_position_embeddings: cfg.original_max_position_embeddings,
368        })
369    }
370
371    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
372        let max_seq_len = cfg.max_position_embeddings;
373        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
374
375        let inv_freq: Vec<_> = (0..dim)
376            .step_by(2)
377            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
378            .collect();
379        let inv_freq_len = inv_freq.len();
380        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
381        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
382            .to_dtype(DType::F32)?
383            .reshape((max_seq_len, 1))?;
384        let freqs = t.matmul(&inv_freq)?;
385        let sin = freqs.sin()?.to_dtype(dtype)?;
386        let cos = freqs.cos()?.to_dtype(dtype)?;
387        Ok(Self {
388            short_cos: cos,
389            short_sin: sin,
390            long_cos: None,
391            long_sin: None,
392            original_max_position_embeddings: cfg.original_max_position_embeddings,
393        })
394    }
395
396    #[allow(clippy::too_many_arguments)]
397    fn new_scaled(
398        short_factor: &[f64],
399        long_factor: &[f64],
400        scaling_type: &ScaledRopeType,
401        long_mscale: f64,
402        short_mscale: f64,
403        cfg: &PhiRopeConfig,
404        dtype: DType,
405        dev: &Device,
406    ) -> Result<Self> {
407        let max_seq_len = cfg.max_position_embeddings;
408        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
409
410        if !matches!(scaling_type, ScaledRopeType::Su) {
411            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
412        }
413
414        if short_factor.len() != dim / 2 {
415            candle_core::bail!(
416                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
417                short_factor.len(),
418                dim / 2
419            );
420        }
421        if long_factor.len() != dim / 2 {
422            candle_core::bail!(
423                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
424                long_factor.len(),
425                dim / 2
426            );
427        }
428
429        // Short cos/sin
430        let inv_freq_short: Vec<_> = (0..dim)
431            .step_by(2)
432            .enumerate()
433            .map(|(k, i)| {
434                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
435            })
436            .collect();
437        let inv_freq_len_short = inv_freq_short.len();
438        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
439        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
440            .to_dtype(DType::F32)?
441            .reshape((max_seq_len, 1))?;
442        let freqs_short = t_short.matmul(&inv_freq_short)?;
443        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
444        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
445
446        // Long cos/sin
447        let inv_freq_long: Vec<_> = (0..dim)
448            .step_by(2)
449            .enumerate()
450            .map(|(k, i)| {
451                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
452            })
453            .collect();
454        let inv_freq_len_long = inv_freq_long.len();
455        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
456        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
457            .to_dtype(DType::F32)?
458            .reshape((max_seq_len, 1))?;
459        let freqs_long = t_long.matmul(&inv_freq_long)?;
460        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
461        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
462        Ok(Self {
463            short_cos: cos_short,
464            short_sin: sin_short,
465            long_cos: Some(cos_long),
466            long_sin: Some(sin_long),
467            original_max_position_embeddings: cfg.original_max_position_embeddings,
468        })
469    }
470
471    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
472        let cfg: PhiRopeConfig = cfg.into();
473
474        match &cfg.rope_scaling {
475            Some(PhiRopeScalingConfig::Classic {
476                short_factor,
477                long_factor,
478                scaling_type,
479            }) => {
480                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
481            }
482
483            Some(PhiRopeScalingConfig::Scaled {
484                short_factor,
485                long_factor,
486                scaling_type,
487                long_mscale,
488                short_mscale,
489            }) => Self::new_scaled(
490                short_factor,
491                long_factor,
492                scaling_type,
493                *long_mscale,
494                *short_mscale,
495                &cfg,
496                dtype,
497                dev,
498            ),
499
500            None => Self::new_unscaled(&cfg, dtype, dev),
501        }
502    }
503
504    /// Returns (sin, cos) taking into account LongRope
505    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
506        if self.long_cos.is_none() {
507            return (&self.short_sin, &self.short_cos);
508        }
509        let seq_len = position_ids.iter().max().unwrap() + 1;
510        if seq_len > self.original_max_position_embeddings {
511            (
512                self.long_sin.as_ref().unwrap(),
513                self.long_cos.as_ref().unwrap(),
514            )
515        } else {
516            (&self.short_sin, &self.short_cos)
517        }
518    }
519
520    pub fn forward(
521        &self,
522        q: &Tensor,
523        k: &Tensor,
524        seqlen_offsets: &[usize],
525        position_ids: &[usize],
526    ) -> Result<(Tensor, Tensor)> {
527        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
528        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
529
530        let rot_dim = cos.dim(D::Minus1)? * 2;
531
        // Partial rotary case (Phi-3 / Phi-4 mini): only the first rot_dim features of each head are rotated.
533        if rot_dim != q.dim(D::Minus1)? {
535            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
536            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
537            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
538            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
539
540            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
541                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
542                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
543                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
544                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
545                (q_embed, k_embed)
546            } else {
547                let mut q_embeds = Vec::new();
548                let mut k_embeds = Vec::new();
549                for (i, offset) in seqlen_offsets.iter().enumerate() {
550                    let cos = cos.narrow(0, *offset, seq_len)?;
551                    let sin = sin.narrow(0, *offset, seq_len)?;
552                    let q_embed = candle_nn::rotary_emb::rope(
553                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
554                        &cos,
555                        &sin,
556                    )?;
557                    let k_embed = candle_nn::rotary_emb::rope(
558                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
559                        &cos,
560                        &sin,
561                    )?;
562                    q_embeds.push(q_embed);
563                    k_embeds.push(k_embed);
564                }
565                let q_rot = Tensor::cat(&q_embeds, 0)?;
566                let k_rot = Tensor::cat(&k_embeds, 0)?;
567                (q_rot, k_rot)
568            };
569
570            Ok((
571                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
572                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
573            ))
574        } else if seqlen_offsets.len() == 1 {
575            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
576            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
577            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
578            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
579            Ok((q_embed, k_embed))
580        } else {
581            let mut q_embeds = Vec::new();
582            let mut k_embeds = Vec::new();
583            for (i, offset) in seqlen_offsets.iter().enumerate() {
584                let cos = cos.narrow(0, *offset, seq_len)?;
585                let sin = sin.narrow(0, *offset, seq_len)?;
586                let q_embed =
587                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
588                let k_embed =
589                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
590                q_embeds.push(q_embed);
591                k_embeds.push(k_embed);
592            }
593            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
594        }
595    }
596}
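
// Call sketch: `forward` takes (batch, n_heads, seq_len, head_dim) queries/keys, one KV-cache
// offset per sequence in the batch, and the absolute position ids used to choose between the
// short and long tables. The single-sequence, offset-0 case below is illustrative only.
#[allow(dead_code)]
fn _example_phi_rope(
    rope: &PhiRotaryEmbedding,
    q: &Tensor,
    k: &Tensor,
) -> Result<(Tensor, Tensor)> {
    let positions: Vec<usize> = (0..q.dim(2)?).collect();
    rope.forward(q, k, &[0], &positions)
}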
597
598/// RoPE for Llama3
599#[derive(Debug, Clone)]
600pub struct Llama3RotaryEmbedding(RotaryEmbedding);
601
602#[derive(Debug, Clone, Deserialize, Serialize, Default)]
603pub enum Llama3RopeType {
604    #[serde(rename = "llama3")]
605    Llama3,
606    #[default]
607    #[serde(rename = "default")]
608    Default,
609}
610
611#[derive(Debug, Clone, Deserialize, Serialize, Default)]
612pub struct Llama3RopeConfig {
613    pub factor: f32,
614    pub low_freq_factor: f32,
615    pub high_freq_factor: f32,
616    pub original_max_position_embeddings: usize,
617    pub rope_type: Llama3RopeType,
618}
619
620fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
621    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
622    (0..head_dim)
623        .step_by(2)
624        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
625        .collect()
626}
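
// Worked example of the llama3 frequency ramp used below (illustrative values): with
// original_max_position_embeddings = 8192, low_freq_factor = 1, and high_freq_factor = 4,
// wavelengths above 8192 are divided by `factor`, wavelengths below 2048 are left untouched,
// and everything in between is linearly blended via `smooth`.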
627
628// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
629impl Llama3RotaryEmbedding {
630    pub fn new_llama3(
631        dtype: DType,
632        cfg: &llama::Config,
633        dev: &Device,
634        is_gpt_neox: bool,
635    ) -> Result<Self> {
636        match &cfg.rope_scaling {
637            None
638            | Some(Llama3RopeConfig {
639                rope_type: Llama3RopeType::Default,
640                ..
641            }) => Ok(Self(RotaryEmbedding::new(
642                cfg.rope_theta,
643                cfg.hidden_size / cfg.num_attention_heads,
644                cfg.max_position_embeddings,
645                dev,
646                is_gpt_neox,
647                dtype,
648            )?)),
649            Some(rope_scaling) => {
650                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
651                    / rope_scaling.low_freq_factor;
652                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
653                    / rope_scaling.high_freq_factor;
654
655                let inv_freq = calculate_default_inv_freq(cfg)
656                    .into_iter()
657                    .map(|freq| {
658                        let wavelen = 2. * PI / freq;
659                        if wavelen < high_freq_wavelen {
660                            freq
661                        } else if wavelen > low_freq_wavelen {
662                            freq / rope_scaling.factor
663                        } else {
664                            let smooth = (rope_scaling.original_max_position_embeddings as f32
665                                / wavelen
666                                - rope_scaling.low_freq_factor)
667                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
668                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
669                        }
670                    })
671                    .collect::<Vec<_>>();
672                let inv_freq_len = inv_freq.len();
673                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
674
675                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
676                    .to_dtype(DType::F32)?
677                    .reshape((cfg.max_position_embeddings, 1))?;
678                let freqs = t.matmul(&inv_freq)?;
679                let sin = freqs.sin()?.to_dtype(dtype)?;
680                let cos = freqs.cos()?.to_dtype(dtype)?;
681                Ok(Self(RotaryEmbedding {
682                    sin,
683                    cos,
684                    is_gpt_neox,
685                }))
686            }
687        }
688    }
689
690    pub fn new_mllama3(
691        dtype: DType,
692        cfg: &MLlamaTextConfig,
693        dev: &Device,
694        is_gpt_neox: bool,
695    ) -> Result<Self> {
696        match &cfg.rope_scaling {
697            None
698            | Some(MLlamaRopeScaling {
699                rope_type: MLlamaRopeType::Default,
700                ..
701            }) => Ok(Self(RotaryEmbedding::new(
702                cfg.rope_theta,
703                cfg.hidden_size / cfg.num_attention_heads,
704                cfg.max_position_embeddings,
705                dev,
706                is_gpt_neox,
707                dtype,
708            )?)),
709            Some(MLlamaRopeScaling {
710                rope_type: MLlamaRopeType::Llama3,
711                original_max_position_embeddings,
712                factor,
713                attention_factor: _,
714                beta_fast: _,
715                beta_slow: _,
716                short_factor: _,
717                long_factor: _,
718                low_freq_factor,
719                high_freq_factor,
720            }) => {
721                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
722                let low_freq_factor = low_freq_factor
723                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
724                let high_freq_factor = high_freq_factor
725                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
726
727                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
728                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
729
730                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
731
732                let inv_freq = (0..head_dim)
733                    .step_by(2)
734                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
735                    .map(|freq| {
736                        let wavelen = 2. * PI / freq;
737                        if wavelen < high_freq_wavelen {
738                            freq
739                        } else if wavelen > low_freq_wavelen {
740                            freq / factor
741                        } else {
742                            let smooth = (*original_max_position_embeddings as f32 / wavelen
743                                - low_freq_factor)
744                                / (high_freq_factor - low_freq_factor);
745                            (1. - smooth) * freq / factor + smooth * freq
746                        }
747                    })
748                    .collect::<Vec<_>>();
749                let inv_freq_len = inv_freq.len();
750                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
751
752                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
753                    .to_dtype(DType::F32)?
754                    .reshape((cfg.max_position_embeddings, 1))?;
755                let freqs = t.matmul(&inv_freq)?;
756                let sin = freqs.sin()?.to_dtype(dtype)?;
757                let cos = freqs.cos()?.to_dtype(dtype)?;
758                Ok(Self(RotaryEmbedding {
759                    sin,
760                    cos,
761                    is_gpt_neox,
762                }))
763            }
764            Some(MLlamaRopeScaling {
765                rope_type: other, ..
766            }) => {
767                candle_core::bail!(
768                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
769                )
770            }
771        }
772    }
773
774    pub fn forward(
775        &self,
776        q: &Tensor,
777        k: &Tensor,
778        seqlen_offsets: &[usize],
779    ) -> Result<(Tensor, Tensor)> {
780        self.0.forward(q, k, seqlen_offsets)
781    }
782}
783
784// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
785#[derive(Debug, Clone)]
786pub struct Qwen2VLRotaryEmbedding {
787    inv_freq: Tensor,
788    mrope_section: Vec<usize>,
789}
790
791impl Qwen2VLRotaryEmbedding {
792    pub fn new(
793        base: f32,
794        head_dim: usize,
795        device: &Device,
796        mrope_section: Vec<usize>,
797    ) -> Result<Self> {
798        let inv_freq: Vec<_> = (0..head_dim)
799            .step_by(2)
800            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
801            .collect();
802        let inv_freq_len = inv_freq.len();
803        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
804        Ok(Self {
805            inv_freq,
806            mrope_section,
807        })
808    }
809
    /// Compute the rotary (cos, sin) tables for the given multimodal position ids.
811    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
812        let inv_freq_expanded =
813            self.inv_freq
814                .reshape((1, 1, (), 1))?
815                .repeat((3, position_ids.dim(1)?, 1, 1))?;
816        let position_ids_expanded = position_ids.unsqueeze(2)?;
817        let freqs = inv_freq_expanded
818            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
819            .transpose(2, 3)?;
820        let cos = freqs.cos()?;
821        let sin = freqs.sin()?;
822
823        let cos = Tensor::cat(
824            &cos.split(&self.mrope_section, D::Minus1)?
825                .into_iter()
826                .enumerate()
827                .map(|(i, m)| m.i(i % 3))
828                .collect::<Result<Vec<_>>>()?,
829            D::Minus1,
830        )?
831        .squeeze(0)?
832        .to_dtype(dtype)?
833        .contiguous()?;
834        let sin = Tensor::cat(
835            &sin.split(&self.mrope_section, D::Minus1)?
836                .into_iter()
837                .enumerate()
838                .map(|(i, m)| m.i(i % 3))
839                .collect::<Result<Vec<_>>>()?,
840            D::Minus1,
841        )?
842        .squeeze(0)?
843        .to_dtype(dtype)?
844        .contiguous()?;
845
846        Ok((cos, sin))
847    }
848
849    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
850    pub fn forward(
851        &self,
852        (cos, sin): &(Tensor, Tensor),
853        q: &mut Tensor,
854        k: &mut Tensor,
855    ) -> Result<()> {
856        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
857        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
858        Ok(())
859    }
860}
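
// Sketch of the M-RoPE flow: `position_ids` carries three rows (temporal, height, width) per
// token, i.e. shape (3, batch, seq_len), and `mrope_section` must sum to head_dim / 2. All of
// the sizes below are illustrative placeholders.
#[allow(dead_code)]
fn _example_mrope(dev: &Device) -> Result<()> {
    let rope = Qwen2VLRotaryEmbedding::new(1e6, 128, dev, vec![16, 24, 24])?;
    let position_ids = Tensor::zeros((3, 1, 8), DType::F32, dev)?;
    let (cos, sin) = rope.compute_cos_sin(&position_ids, DType::F32)?;
    // 32 query heads, 8 KV heads, 8 positions, 128-dim heads.
    let mut q = Tensor::zeros((1, 32, 8, 128), DType::F32, dev)?;
    let mut k = Tensor::zeros((1, 8, 8, 128), DType::F32, dev)?;
    rope.forward(&(cos, sin), &mut q, &mut k)
}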
861
862#[derive(Debug, Clone)]
863pub struct Qwen2_5VLRotaryEmbedding {
864    inv_freq: Tensor,
865    mrope_section: Vec<usize>,
866}
867
868impl Qwen2_5VLRotaryEmbedding {
869    pub fn new(
870        base: f32,
871        head_dim: usize,
872        device: &Device,
873        mrope_section: Vec<usize>,
874    ) -> Result<Self> {
875        let inv_freq: Vec<_> = (0..head_dim)
876            .step_by(2)
877            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
878            .collect();
879        let inv_freq_len = inv_freq.len();
880        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
881        Ok(Self {
882            inv_freq,
883            mrope_section,
884        })
885    }
886
    /// Compute the rotary (cos, sin) tables for the given multimodal position ids.
888    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
889        let inv_freq_expanded =
890            self.inv_freq
891                .reshape((1, 1, (), 1))?
892                .repeat((3, position_ids.dim(1)?, 1, 1))?;
893        let position_ids_expanded = position_ids.unsqueeze(2)?;
894        let freqs = inv_freq_expanded
895            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
896            .transpose(2, 3)?;
897        let cos = freqs.cos()?;
898        let sin = freqs.sin()?;
899
900        let cos = Tensor::cat(
901            &cos.split(&self.mrope_section, D::Minus1)?
902                .into_iter()
903                .enumerate()
904                .map(|(i, m)| m.i(i % 3))
905                .collect::<Result<Vec<_>>>()?,
906            D::Minus1,
907        )?
908        .squeeze(0)?
909        .to_dtype(dtype)?
910        .contiguous()?;
911        let sin = Tensor::cat(
912            &sin.split(&self.mrope_section, D::Minus1)?
913                .into_iter()
914                .enumerate()
915                .map(|(i, m)| m.i(i % 3))
916                .collect::<Result<Vec<_>>>()?,
917            D::Minus1,
918        )?
919        .squeeze(0)?
920        .to_dtype(dtype)?
921        .contiguous()?;
922
923        Ok((cos, sin))
924    }
925
926    pub fn forward(
927        &self,
928        (cos, sin): &(Tensor, Tensor),
929        q: &mut Tensor,
930        k: &mut Tensor,
931    ) -> Result<()> {
932        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
933        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
934        Ok(())
935    }
936}
937
938#[derive(Debug, Clone)]
939pub struct DeepSeekV2RotaryEmbedding {
940    sin: Tensor,
941    cos: Tensor,
942}
943
944#[derive(Debug, Clone, Deserialize, Serialize)]
945#[serde(untagged)]
946pub enum DeepSeekV2RopeScaling {
947    Yarn {
948        original_max_position_embeddings: usize,
949        beta_fast: f32,
950        beta_slow: f32,
951        mscale: f32,
952        mscale_all_dim: f32,
953        factor: f32,
954        #[serde(rename = "type")]
955        scaling_type: ScaledRopeType,
956    },
957    LinearOrDynamic {
958        #[serde(rename = "type")]
959        scaling_type: ScaledRopeType,
960        factor: f64,
961    },
962}
963
964pub struct DeepSeekV2RopeConfig {
965    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
966    pub max_position_embeddings: usize,
967    pub rope_theta: f32,
968    pub qk_rope_head_dim: usize,
969}
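
// Sketch: a YaRN-scaled rope config in the shape this module expects; the numbers are
// placeholders rather than values from any released checkpoint.
#[allow(dead_code)]
fn _example_deepseek_rope_cfg() -> DeepSeekV2RopeConfig {
    DeepSeekV2RopeConfig {
        rope_scaling: Some(DeepSeekV2RopeScaling::Yarn {
            original_max_position_embeddings: 4096,
            beta_fast: 32.,
            beta_slow: 1.,
            mscale: 1.,
            mscale_all_dim: 1.,
            factor: 40.,
            scaling_type: ScaledRopeType::Yarn,
        }),
        max_position_embeddings: 4096 * 40,
        rope_theta: 10_000.,
        qk_rope_head_dim: 64,
    }
}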
970
971impl DeepSeekV2RotaryEmbedding {
972    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
973        let max_seq_len = cfg.max_position_embeddings;
974        let dim = cfg.qk_rope_head_dim;
975
976        let inv_freq: Vec<_> = (0..dim)
977            .step_by(2)
978            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
979            .collect();
980        let inv_freq_len = inv_freq.len();
981        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
982        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
983            .to_dtype(DType::F32)?
984            .reshape((max_seq_len, 1))?;
985        let freqs = t.matmul(&inv_freq)?;
986
987        let sin = freqs.sin()?.to_dtype(dtype)?;
988        let cos = freqs.cos()?.to_dtype(dtype)?;
989
990        Ok(Self { sin, cos })
991    }
992
993    fn yarn_find_correction_dim(
994        num_rot: f32,
995        dim: usize,
996        base: f32,
997        max_position_embeddings: usize,
998    ) -> f32 {
999        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1000            / (2. * base.ln())
1001    }
1002
1003    fn yarn_find_correction_range(
1004        low_rot: f32,
1005        high_rot: f32,
1006        dim: usize,
1007        base: f32,
1008        max_position_embeddings: usize,
1009    ) -> (f32, f32) {
1010        let low =
1011            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1012        let high =
1013            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1014        (low.max(0.), high.min(dim as f32 - 1.))
1015    }
1016
1017    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1018        if min == max {
1019            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
1020            max += 0.001;
1021        }
1022        let linear_func =
1023            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1024        linear_func.clamp(0., 1)
1025    }
1026
1027    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1028        if scale <= 1. {
1029            return 1.;
1030        }
1031        0.1 * mscale * scale.ln() + 1.
1032    }
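
    // For yarn_get_mscale above: with factor = 40 and mscale = 1.0 the attention scale is
    // 0.1 * 1.0 * ln(40) + 1 ≈ 1.37, while any factor <= 1 leaves it at exactly 1.0.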
1033
1034    #[allow(clippy::too_many_arguments)]
1035    fn new_yarn(
1036        cfg: &DeepSeekV2RopeConfig,
1037        dtype: DType,
1038        dev: &Device,
1039        original_max_position_embeddings: usize,
1040        beta_fast: f32,
1041        beta_slow: f32,
1042        factor: f32,
1043        mscale: f32,
1044        mscale_all_dim: f32,
1045    ) -> Result<Self> {
1046        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1047            .step_by(2)
1048            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1049            .collect();
1050        let freq_extra_len = freq_extra.len();
1051        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1052        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1053            .step_by(2)
1054            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1055            .collect();
1056        let freq_inter_len = freq_inter.len();
1057        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1058
1059        let (low, high) = Self::yarn_find_correction_range(
1060            beta_fast,
1061            beta_slow,
1062            cfg.qk_rope_head_dim,
1063            cfg.rope_theta,
1064            original_max_position_embeddings,
1065        );
1066        let inv_freq_mask =
1067            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1068        let inv_freq = freq_inter
1069            .broadcast_mul(&(1. - &inv_freq_mask)?)?
1070            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1071
1072        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1073            .to_dtype(DType::F32)?
1074            .reshape((cfg.max_position_embeddings, 1))?;
1075        let freqs = t.matmul(&inv_freq)?;
1076
1077        let mscale =
1078            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1079        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1080        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1081
1082        Ok(Self { sin, cos })
1083    }
1084
1085    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1086        match &cfg.rope_scaling {
1087            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1088                scaling_type: _,
1089                factor: _,
1090            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1091            Some(DeepSeekV2RopeScaling::Yarn {
1092                original_max_position_embeddings,
1093                beta_fast,
1094                beta_slow,
1095                factor,
1096                mscale,
1097                mscale_all_dim,
1098                scaling_type: _,
1099            }) => Self::new_yarn(
1100                cfg,
1101                dtype,
1102                dev,
1103                *original_max_position_embeddings,
1104                *beta_fast,
1105                *beta_slow,
1106                *factor,
1107                *mscale,
1108                *mscale_all_dim,
1109            ),
1110            None => Self::new_unscaled(cfg, dtype, dev),
1111        }
1112    }
1113
1114    pub fn forward(
1115        &self,
1116        q: &Tensor,
1117        k: &Tensor,
1118        seqlen_offsets: &[usize],
1119    ) -> Result<(Tensor, Tensor)> {
1120        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1121
1122        if seqlen_offsets.len() == 1 {
1123            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1124            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1125            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1126            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1127            Ok((q_embed, k_embed))
1128        } else {
1129            let mut q_embeds = Vec::new();
1130            let mut k_embeds = Vec::new();
1131            for (i, offset) in seqlen_offsets.iter().enumerate() {
1132                let cos = self.cos.narrow(0, *offset, seq_len)?;
1133                let sin = self.sin.narrow(0, *offset, seq_len)?;
1134                let q_embed = candle_nn::rotary_emb::rope_i(
1135                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
1136                    &cos,
1137                    &sin,
1138                )?;
1139                let k_embed = candle_nn::rotary_emb::rope_i(
1140                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
1141                    &cos,
1142                    &sin,
1143                )?;
1144                q_embeds.push(q_embed);
1145                k_embeds.push(k_embed);
1146            }
1147            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1148        }
1149    }
1150}
1151
1152#[derive(Debug, Clone)]
1153pub struct Phi4MMRotaryEmbedding {
1154    short_sin: Tensor,
1155    short_cos: Tensor,
1156    long_cos: Option<Tensor>,
1157    long_sin: Option<Tensor>,
1158    original_max_position_embeddings: usize,
1159}
1160
1161#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1162#[serde(rename_all = "lowercase")]
1163pub enum Phi4MMScaledRopeType {
1164    #[serde(alias = "longrope")]
1165    LongRope,
1166    #[default]
1167    Default,
1168}
1169
1170#[derive(Debug, Clone, Deserialize, Serialize)]
1171pub struct Phi4MMRopeScalingConfig {
1172    short_factor: Option<Vec<f64>>,
1173    long_factor: Option<Vec<f64>>,
1174    #[serde(rename = "type")]
1175    scaling_type: Phi4MMScaledRopeType,
1176}
1177
1178impl Phi4MMRotaryEmbedding {
1179    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1180        let max_seq_len = cfg.max_position_embeddings;
1181        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1182
1183        let inv_freq: Vec<_> = (0..dim)
1184            .step_by(2)
1185            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1186            .collect();
1187        let inv_freq_len = inv_freq.len();
1188        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1189        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1190            .to_dtype(DType::F32)?
1191            .reshape((max_seq_len, 1))?;
1192        let freqs = t.matmul(&inv_freq)?;
1193        let sin = freqs.sin()?.to_dtype(dtype)?;
1194        let cos = freqs.cos()?.to_dtype(dtype)?;
1195        Ok(Self {
1196            short_cos: cos,
1197            short_sin: sin,
1198            long_cos: None,
1199            long_sin: None,
1200            original_max_position_embeddings: cfg.original_max_position_embeddings,
1201        })
1202    }
1203
1204    #[allow(clippy::too_many_arguments)]
1205    fn new_longrope(
1206        short_factor: &[f64],
1207        long_factor: &[f64],
1208        cfg: &Phi4MMConfig,
1209        dtype: DType,
1210        dev: &Device,
1211    ) -> Result<Self> {
1212        let max_seq_len = cfg.max_position_embeddings;
1213        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1214
1215        // Calculate scale
1216        let scale =
1217            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1218        let scaling_factor = if scale <= 1.0 {
1219            1.0
1220        } else {
1221            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1222        };
1223
1224        // Short cos/sin
1225        let inv_freq_short: Vec<_> = (0..dim)
1226            .step_by(2)
1227            .enumerate()
1228            .map(|(k, i)| {
1229                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1230            })
1231            .collect();
1232        let inv_freq_len_short = inv_freq_short.len();
1233        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1234        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1235            .to_dtype(DType::F32)?
1236            .reshape((max_seq_len, 1))?;
1237        let freqs_short = t_short.matmul(&inv_freq_short)?;
1238        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1239        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1240
1241        // Long cos/sin
1242        let inv_freq_long: Vec<_> = (0..dim)
1243            .step_by(2)
1244            .enumerate()
1245            .map(|(k, i)| {
1246                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1247            })
1248            .collect();
1249        let inv_freq_len_long = inv_freq_long.len();
1250        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1251        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1252            .to_dtype(DType::F32)?
1253            .reshape((max_seq_len, 1))?;
1254        let freqs_long = t_long.matmul(&inv_freq_long)?;
1255        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1256        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1257
1258        Ok(Self {
1259            short_cos: cos_short,
1260            short_sin: sin_short,
1261            long_cos: Some(cos_long),
1262            long_sin: Some(sin_long),
1263            original_max_position_embeddings: cfg.original_max_position_embeddings,
1264        })
1265    }
1266
1267    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1268        match &cfg.rope_scaling {
1269            Some(Phi4MMRopeScalingConfig {
1270                scaling_type: Phi4MMScaledRopeType::LongRope,
1271                short_factor: Some(short_factor),
1272                long_factor: Some(long_factor),
1273            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1274
1275            _ => Self::new_unscaled(cfg, dtype, dev),
1276        }
1277    }
1278
1279    /// Returns (sin, cos) taking into account LongRope
1280    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1281        if self.long_cos.is_none() {
1282            return (&self.short_sin, &self.short_cos);
1283        }
1284        let seq_len = position_ids.iter().max().unwrap() + 1;
1285        if seq_len > self.original_max_position_embeddings {
1286            (
1287                self.long_sin.as_ref().unwrap(),
1288                self.long_cos.as_ref().unwrap(),
1289            )
1290        } else {
1291            (&self.short_sin, &self.short_cos)
1292        }
1293    }
1294
1295    pub fn forward(
1296        &self,
1297        q: &Tensor,
1298        k: &Tensor,
1299        seqlen_offsets: &[usize],
1300        position_ids: &[usize],
1301    ) -> Result<(Tensor, Tensor)> {
1302        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1303        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1304
1305        let rot_dim = cos.dim(D::Minus1)? * 2;
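        // Partial rotary: only the first `rot_dim` features of each head are rotated; the
        // remaining `head_dim - rot_dim` features pass through unchanged below.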
1306        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1307        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1308        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1309        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1310
1311        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1312            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1313            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1314            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1315            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1316            (q_embed, k_embed)
1317        } else {
1318            let mut q_embeds = Vec::new();
1319            let mut k_embeds = Vec::new();
1320            for (i, offset) in seqlen_offsets.iter().enumerate() {
1321                let cos = cos.narrow(0, *offset, seq_len)?;
1322                let sin = sin.narrow(0, *offset, seq_len)?;
1323                let q_embed = candle_nn::rotary_emb::rope(
1324                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1325                    &cos,
1326                    &sin,
1327                )?;
1328                let k_embed = candle_nn::rotary_emb::rope(
1329                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1330                    &cos,
1331                    &sin,
1332                )?;
1333                q_embeds.push(q_embed);
1334                k_embeds.push(k_embed);
1335            }
1336            let q_rot = Tensor::cat(&q_embeds, 0)?;
1337            let k_rot = Tensor::cat(&k_embeds, 0)?;
1338            (q_rot, k_rot)
1339        };
1340
1341        Ok((
1342            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1343            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1344        ))
1345    }
1346}
1347
1348#[derive(Debug, Clone)]
1349pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1350
1351#[derive(Debug, Clone, Deserialize, Serialize)]
1352#[serde(rename_all = "lowercase")]
1353pub enum Gemmma3ScaledRopeType {
1354    #[serde(alias = "linear")]
1355    Linear,
1356}
1357
1358#[derive(Debug, Clone, Deserialize, Serialize)]
1359pub struct Gemma3RopeScalingConfig {
1360    factor: f64,
1361    rope_type: Gemmma3ScaledRopeType,
1362}
1363
1364impl Gemma3RotaryEmbedding {
1365    fn new_linear(
1366        cfg: &Gemma3TextConfig,
1367        factor: f64,
1368        is_gpt_neox: bool,
1369        dtype: DType,
1370        dev: &Device,
1371    ) -> Result<Self> {
1372        let max_seq_len = cfg.max_position_embeddings;
1373        let dim = cfg.head_dim;
1374
1375        let inv_freq: Vec<_> = (0..dim)
1376            .step_by(2)
1377            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1378            .collect();
1379        let inv_freq_len = inv_freq.len();
1380        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1381        let inv_freq = (inv_freq / factor)?;
1382
1383        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1384            .to_dtype(DType::F32)?
1385            .reshape((max_seq_len, 1))?;
1386        let freqs = t.matmul(&inv_freq)?;
1387        let sin = freqs.sin()?.to_dtype(dtype)?;
1388        let cos = freqs.cos()?.to_dtype(dtype)?;
1389        Ok(Self(RotaryEmbedding {
1390            cos,
1391            sin,
1392            is_gpt_neox,
1393        }))
1394    }
1395
1396    pub fn new(
1397        is_gpt_neox: bool,
1398        dtype: DType,
1399        cfg: &Gemma3TextConfig,
1400        dev: &Device,
1401    ) -> Result<Self> {
1402        match &cfg.rope_scaling {
1403            Some(Gemma3RopeScalingConfig {
1404                rope_type: Gemmma3ScaledRopeType::Linear,
1405                factor,
1406            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1407
1408            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1409        }
1410    }
1411
1412    pub fn forward(
1413        &self,
1414        q: &Tensor,
1415        k: &Tensor,
1416        seqlen_offsets: &[usize],
1417    ) -> Result<(Tensor, Tensor)> {
1418        self.0.forward(q, k, seqlen_offsets)
1419    }
1420}
1421
1422#[derive(Debug, Clone)]
1423pub struct QLinear {
1424    inner: QMatMul,
1425    bias: Option<Tensor>,
1426    dtype: DType,
1427}
1428
1429impl QLinear {
1430    pub fn new<R: std::io::Read + std::io::Seek>(
1431        ct: &mut Content<'_, R>,
1432        name: &str,
1433        device: &Device,
1434    ) -> Result<Self> {
1435        let w = ct.tensor(&format!("{name}.weight"), device)?;
1436        let b = ct.tensor(&format!("{name}.bias"), device)?;
1437        let inner = QMatMul::from_qtensor(w)?;
1438        let bias = b.dequantize(device)?;
1439        Ok(Self {
1440            inner,
1441            bias: Some(bias),
1442            dtype: DType::F32,
1443        })
1444    }
1445
1446    pub fn from_linear(linear: Linear) -> Self {
1447        Self {
1448            inner: QMatMul::Tensor(linear.weight().clone()),
1449            bias: linear.bias().cloned(),
1450            dtype: linear.weight().dtype(),
1451        }
1452    }
1453
1454    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1455        let dtype = w.dtype();
1456        Self {
1457            inner: QMatMul::Tensor(w),
1458            bias: b,
1459            dtype,
1460        }
1461    }
1462
1463    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1464        if let Some(ref b) = b {
1465            assert_eq!(b.dtype(), DType::F32);
1466        }
1467        Self {
1468            inner: QMatMul::QTensor(Arc::new(w)),
1469            bias: b,
1470            dtype: DType::F32,
1471        }
1472    }
1473
1474    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1475        Self {
1476            inner,
1477            bias: old.bias.clone(),
1478            dtype: old.dtype,
1479        }
1480    }
1481
1482    pub fn inner(&mut self) -> &mut QMatMul {
1483        &mut self.inner
1484    }
1485
1486    pub fn inner_ref(&self) -> &QMatMul {
1487        &self.inner
1488    }
1489
1490    pub fn is_quant(&self) -> bool {
1491        matches!(self.inner, QMatMul::QTensor(_))
1492    }
1493
1494    pub fn bias(&self) -> Option<&Tensor> {
1495        self.bias.as_ref()
1496    }
1497
1498    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1499        self.bias.as_mut()
1500    }
1501}
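
// Sketch: wrap an existing float `Linear` (the activation dtype is preserved), or use
// `from_qparts` for quantized weights (which expects an F32 bias). Sizes and the vb prefix
// below are illustrative.
#[allow(dead_code)]
fn _example_qlinear(vb: ShardedVarBuilder) -> Result<QLinear> {
    let lin = linear_no_bias(4096, 11008, vb.pp("gate_proj"))?;
    Ok(QLinear::from_linear(lin))
}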
1502
1503impl Module for QLinear {
1504    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1505        let xs = if self.is_quant() {
1506            xs.to_dtype(DType::F32)?
1507        } else {
1508            xs.clone()
1509        };
1510        if let Some(bias) = &self.bias {
1511            self.inner
1512                .forward(&xs)?
1513                .broadcast_add(bias)?
1514                .to_dtype(self.dtype)
1515        } else {
1516            self.inner.forward(&xs)?.to_dtype(self.dtype)
1517        }
1518    }
1519}
1520
1521#[derive(Debug, Clone)]
1522pub struct RotaryEmbedding {
1523    cos: Tensor,
1524    sin: Tensor,
1525    is_gpt_neox: bool,
1526}
1527
1528impl RotaryEmbedding {
1529    pub fn new(
1530        base: f32,
1531        head_dim: usize,
1532        max_position_embeddings: usize,
1533        device: &Device,
1534        is_gpt_neox: bool,
1535        dtype: DType,
1536    ) -> Result<Self> {
1537        let inv_freq: Vec<_> = (0..head_dim)
1538            .step_by(2)
1539            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1540            .collect();
1541        let inv_freq_len = inv_freq.len();
1542        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1543        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1544            .to_dtype(DType::F32)?
1545            .reshape((max_position_embeddings, 1))?;
1546        let freqs = t.matmul(&inv_freq)?;
1547        let sin = freqs.sin()?.to_dtype(dtype)?;
1548        let cos = freqs.cos()?.to_dtype(dtype)?;
1549
1550        Ok(Self {
1551            cos,
1552            sin,
1553            is_gpt_neox,
1554        })
1555    }
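
    // Construction sketch (illustrative placeholder values, not tied to a specific model):
    // 10k theta base, 128-dim heads, 4096 max positions, GPT-NeoX style rotation.
    #[allow(dead_code)]
    fn _example_new(device: &Device) -> Result<Self> {
        Self::new(10_000.0, 128, 4096, device, true, DType::F32)
    }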
1556
1557    pub fn new_partial(
1558        base: f32,
1559        rot_dim: usize,
1560        max_position_embeddings: usize,
1561        device: &Device,
1562        is_gpt_neox: bool,
1563        dtype: DType,
1564    ) -> Result<Self> {
1565        let inv_freq: Vec<_> = (0..rot_dim)
1566            .step_by(2)
1567            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
1568            .collect();
1569        let inv_freq_len = inv_freq.len();
1570        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1571        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1572            .to_dtype(DType::F32)?
1573            .reshape((max_position_embeddings, 1))?;
1574        let freqs = t.matmul(&inv_freq)?;
1575        let sin = freqs.sin()?.to_dtype(dtype)?;
1576        let cos = freqs.cos()?.to_dtype(dtype)?;
1577
1578        Ok(Self {
1579            cos,
1580            sin,
1581            is_gpt_neox,
1582        })
1583    }
1584
1585    pub fn forward(
1586        &self,
1587        q: &Tensor,
1588        k: &Tensor,
1589        seqlen_offsets: &[usize],
1590    ) -> Result<(Tensor, Tensor)> {
1591        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
1592        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
1593
1594        let rope = if self.is_gpt_neox {
1595            candle_nn::rotary_emb::rope
1596        } else {
1597            candle_nn::rotary_emb::rope_i
1598        };
1599
1600        if cfg!(feature = "cuda") {
1601            let (cos, sin) = if seqlen_offsets.len() == 1 {
1602                (
1603                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
1604                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
1605                )
1606            } else {
1607                let mut cos_s = Vec::new();
1608                let mut sin_s = Vec::new();
1609                for offset in seqlen_offsets {
1610                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
1611                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
1612                }
1613                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
1614            };
1615
1616            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
1617            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
1618            mistralrs_quant::rotary::apply_rotary_inplace(
1619                &q_embed,
1620                &k_embed,
1621                &cos,
1622                &sin,
1623                self.is_gpt_neox,
1624            )?;
1625            let mut q = q_embed
1626                .reshape((b_sz, seq_len, qh, n_embd))?
1627                .transpose(1, 2)?;
1628            let mut k = k_embed
1629                .reshape((b_sz, seq_len, kh, n_embd))?
1630                .transpose(1, 2)?;
1631            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
1632                q = q.contiguous()?;
1633                k = k.contiguous()?;
1634            }
1635            Ok((q, k))
1636        } else if seqlen_offsets.len() == 1 {
1637            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1638            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1639            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1640            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
1641            Ok((q_embed, k_embed))
1642        } else {
1643            let mut q_embeds = Vec::new();
1644            let mut k_embeds = Vec::new();
1645            for (i, offset) in seqlen_offsets.iter().enumerate() {
1646                let cos = self.cos.narrow(0, *offset, seq_len)?;
1647                let sin = self.sin.narrow(0, *offset, seq_len)?;
1648                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1649                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1650                q_embeds.push(q_embed);
1651                k_embeds.push(k_embed);
1652            }
1653            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1654        }
1655    }
1656}
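
// Illustrative only: a hedged usage sketch for `RotaryEmbedding` on the non-CUDA path
// (the test is gated off when the `cuda` feature is enabled). The base, head_dim, and
// sequence sizes are arbitrary example values; `&[0]` means one sequence at KV-cache offset 0.
#[cfg(test)]
mod rotary_embedding_example {
    use super::*;

    #[cfg(not(feature = "cuda"))]
    #[test]
    fn rope_preserves_shapes() -> Result<()> {
        let dev = Device::Cpu;
        // head_dim = 8, cache covering 32 positions, GPT-NeoX style rotation.
        let rope = RotaryEmbedding::new(10_000., 8, 32, &dev, true, DType::F32)?;

        // q/k are (batch, num_heads, seq_len, head_dim).
        let q = Tensor::randn(0f32, 1f32, (1, 2, 4, 8), &dev)?;
        let k = Tensor::randn(0f32, 1f32, (1, 2, 4, 8), &dev)?;
        let (q, k) = rope.forward(&q, &k, &[0])?;
        assert_eq!(q.dims(), &[1, 2, 4, 8]);
        assert_eq!(k.dims(), &[1, 2, 4, 8]);
        Ok(())
    }
}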
1657
1658#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
1659#[serde(rename_all = "lowercase")]
1660pub enum Activation {
1661    #[default]
1662    #[serde(alias = "gelu")]
1663    Gelu,
1664    #[serde(alias = "gelu_new")]
1665    NewGelu,
1666    Relu,
1667    Relu2,
1668    Relu6,
1669    Silu,
1670    Sigmoid,
1671    HardSigmoid,
1672    Swiglu,
1673    Swish,
1674    HardSwish,
1675    Elu(f64),
1676    LeakyRelu(f64),
1677    #[serde(alias = "gelu_pytorch_tanh")]
1678    GeluPytorchTanh,
1679    QuickGelu,
1680}
1681
1682impl Module for Activation {
1683    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1684        match self {
1685            Self::Gelu => xs.gelu_erf(),
1686            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
1687            Self::NewGelu => xs.gelu(),
1688            Self::Relu => xs.relu(),
1689            Self::Relu2 => xs.relu()?.sqr(),
1690            Self::Relu6 => xs.clamp(0f32, 6f32),
1691            Self::Silu => xs.silu(),
1692            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
1693            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
1694            Self::Swiglu => candle_nn::ops::swiglu(xs),
1695            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
1696            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
1697            &Self::Elu(alpha) => xs.elu(alpha),
1698            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
1699            Self::GeluPytorchTanh => xs.gelu(),
1700            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
1701        }
1702    }
1703}
1704
1705impl TryInto<candle_nn::Activation> for Activation {
1706    type Error = candle_core::Error;
1707
1708    fn try_into(self) -> Result<candle_nn::Activation> {
1709        match self {
1710            Self::Gelu => Ok(candle_nn::Activation::Gelu),
1711            Self::Relu => Ok(candle_nn::Activation::Relu),
1712            Self::Silu => Ok(candle_nn::Activation::Silu),
1713            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
1714            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
1715            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
1716            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
1717            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
1718            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
1719            Self::Swish => Ok(candle_nn::Activation::Swish),
1720            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
1721            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
1722            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
1723            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
1724            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
1725        }
1726    }
1727}
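
// Illustrative only: a small check of the `Activation` forward path and of the
// `TryInto<candle_nn::Activation>` conversion, with arbitrary example inputs.
#[cfg(test)]
mod activation_example {
    use super::*;

    #[test]
    fn activation_forward_and_mapping() -> Result<()> {
        let dev = Device::Cpu;
        let xs = Tensor::new(&[-1.0f32, 0.0, 2.0], &dev)?;

        // Relu zeroes out the negatives and passes positives through.
        let ys = Activation::Relu.forward(&xs)?;
        assert_eq!(ys.to_vec1::<f32>()?, vec![0.0, 0.0, 2.0]);

        // QuickGelu has no candle_nn equivalent, so the conversion must fail.
        let quick: Result<candle_nn::Activation> = Activation::QuickGelu.try_into();
        assert!(quick.is_err());
        Ok(())
    }
}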
1728
1729#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1730pub struct Conv3dConfig {
1731    pub padding: usize,
1732    pub stride: usize,
1733    pub dilation: usize,
1734    pub groups: usize,
1735}
1736
1737impl Default for Conv3dConfig {
1738    fn default() -> Self {
1739        Self {
1740            padding: 0,
1741            stride: 1,
1742            dilation: 1,
1743            groups: 1,
1744        }
1745    }
1746}
1747
1748pub struct Conv3dNoBias {
1749    conv2d_1: Conv2d,
1750    conv2d_2: Conv2d,
1751}
1752
1753impl Conv3dNoBias {
1754    pub fn new(
1755        in_channels: usize,
1756        out_channels: usize,
1757        kernel_sizes: [usize; 3],
1758        cfg: Conv3dConfig,
1759        vb: ShardedVarBuilder,
1760    ) -> Result<Self> {
1761        let ws = vb.get(
1762            (
1763                out_channels,
1764                in_channels / cfg.groups,
1765                kernel_sizes[0],
1766                kernel_sizes[1],
1767                kernel_sizes[2],
1768            ),
1769            "weight",
1770        )?;
1771
1772        // Split on temporal dimension
1773        // https://github.com/pytorch/pytorch/issues/139066
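        // Note: this split assumes a temporal kernel size of 2 (kernel_sizes[0] == 2),
        // matching the two temporal slices recombined in `forward` below.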
1774
1775        let w1 = ws.i((.., .., 0, .., ..))?;
1776        let w2 = ws.i((.., .., 1, .., ..))?;
1777
1778        let cfg = Conv2dConfig {
1779            padding: cfg.padding,
1780            stride: cfg.stride,
1781            dilation: cfg.dilation,
1782            groups: cfg.groups,
1783        };
1784
1785        Ok(Self {
1786            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
1787            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
1788        })
1789    }
1790}
1791
1792impl Module for Conv3dNoBias {
1793    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
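        // Mirror the constructor's split: take the two temporal slices (dim 2 of the NCDHW
        // input), run each through its 2D conv, sum them, and restore the temporal dim.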
1794        let xs1 = xs.i((.., .., 0, .., ..))?;
1795        let xs2 = xs.i((.., .., 1, .., ..))?;
1796
1797        (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
1798    }
1799}
1800
1801pub trait TensorInfExtend {
1802    fn is_inf(&self) -> Result<Self>
1803    where
1804        Self: Sized;
1805    fn any(&self) -> Result<bool>;
1806}
1807
1808impl TensorInfExtend for Tensor {
1809    fn is_inf(&self) -> Result<Self> {
1810        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
1811    }
1812
1813    fn any(&self) -> Result<bool> {
1814        let sum = self.sum_all()?;
1815        match self.dtype() {
            // A nonzero sum means at least one element of the (0/1) mask is set.
            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
1826        }
1827    }
1828}
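
// Illustrative only: expected behavior of `TensorInfExtend` on arbitrary example data.
// Note that `is_inf` compares against +INFINITY only (negative infinities are not matched).
#[cfg(test)]
mod tensor_inf_example {
    use super::*;

    #[test]
    fn detects_positive_infinity() -> Result<()> {
        let dev = Device::Cpu;
        let finite = Tensor::new(&[1.0f32, -2.0, 3.0], &dev)?;
        assert!(!finite.is_inf()?.any()?);

        let with_inf = Tensor::new(&[1.0f32, f32::INFINITY, 3.0], &dev)?;
        assert!(with_inf.is_inf()?.any()?);
        Ok(())
    }
}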
1829
1830pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
1831    let mut max = match xs.dtype() {
1832        DType::U8 => u8::MAX as f32 - 1000.,
1833        DType::U32 => u32::MAX as f32 - 1000.,
1834        DType::I16 => i16::MAX as f32 - 1000.,
1835        DType::I32 => i32::MAX as f32 - 1000.,
1836        DType::I64 => i64::MAX as f32 - 1000.,
1837        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
1838        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
1839        DType::F32 => f32::MAX - 1000.,
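        // `f64::MAX as f32` saturates to +inf, so F64 inputs are effectively left unclamped here.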
1840        DType::F64 => f64::MAX as f32 - 1000.,
1841        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
1842    };
1843    if xs.is_inf()?.any()? {
1844        max -= 1000.;
1845    }
1846    xs.clamp(-max, max)
1847}
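
// Illustrative only: `clamp_for_f16` on an F32 tensor containing +INFINITY. The infinite
// entry is pulled back to the dtype's finite ceiling; the concrete values are arbitrary.
#[cfg(test)]
mod clamp_for_f16_example {
    use super::*;

    #[test]
    fn clamps_infinities_to_finite_values() -> Result<()> {
        let dev = Device::Cpu;
        let xs = Tensor::new(&[1.0f32, f32::INFINITY, -3.0], &dev)?;
        let ys = clamp_for_f16(&xs)?.to_vec1::<f32>()?;
        assert!(ys.iter().all(|v| v.is_finite()));
        assert_eq!(ys[0], 1.0);
        Ok(())
    }
}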
1848
1849pub struct FloatInfo {
1850    /// Minimum representable value.
1851    pub min: f64,
1852    /// Maximum representable value.
1853    pub max: f64,
1854    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
1855    pub eps: f64,
1856    pub dtype: DType,
1857}
1858
1859pub trait GetFloatInfo {
1860    fn finfo(&self) -> Result<FloatInfo>;
1861}
1862
1863impl GetFloatInfo for DType {
1864    fn finfo(&self) -> Result<FloatInfo> {
1865        let finfo = match self {
1866            Self::BF16 => FloatInfo {
1867                min: bf16::MIN.to_f64(),
1868                max: bf16::MAX.to_f64(),
1869                eps: bf16::EPSILON.to_f64(),
1870                dtype: DType::BF16,
1871            },
1872            Self::F16 => FloatInfo {
1873                min: f16::MIN.to_f64(),
1874                max: f16::MAX.to_f64(),
1875                eps: f16::EPSILON.to_f64(),
1876                dtype: DType::F16,
1877            },
1878            Self::F32 => FloatInfo {
1879                min: f32::MIN as f64,
1880                max: f32::MAX as f64,
1881                eps: f32::EPSILON as f64,
1882                dtype: DType::F32,
1883            },
1884            Self::F64 => FloatInfo {
1885                min: f64::MIN,
1886                max: f64::MAX,
1887                eps: f64::EPSILON,
1888                dtype: DType::F64,
1889            },
1890            Self::F8E4M3 => FloatInfo {
1891                min: F8E4M3::MIN.to_f64(),
1892                max: F8E4M3::MAX.to_f64(),
1893                eps: F8E4M3::EPSILON.to_f64(),
1894                dtype: DType::F8E4M3,
1895            },
1896            other => {
1897                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
1898            }
1899        };
1900        Ok(finfo)
1901    }
1902}
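
// Illustrative only: `finfo` mirrors the usual "float info" queries; the dtypes below are
// just examples, and integer dtypes are expected to error.
#[cfg(test)]
mod finfo_example {
    use super::*;

    #[test]
    fn finfo_reports_float_ranges() -> Result<()> {
        let f16_info = DType::F16.finfo()?;
        assert_eq!(f16_info.max, half::f16::MAX.to_f64());
        assert!(DType::U8.finfo().is_err());
        Ok(())
    }
}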
1903
1904#[derive(Clone)]
1905pub struct Mlp {
1906    pub gate: Arc<dyn QuantMethod>,
1907    pub up: Arc<dyn QuantMethod>,
1908    pub down: Arc<dyn QuantMethod>,
1909    act: Activation,
1910    params: Vec<usize>,
1911}
1912
1913impl Mlp {
1914    pub fn new(
1915        vb: ShardedVarBuilder,
1916        hidden_size: usize,
1917        intermediate_size: usize,
1918        quantization_config: &Option<QuantizedConfig>,
1919        hidden_act: Activation,
1920        comm: &Arc<mistralrs_quant::Comm>,
1921    ) -> Result<Self> {
1922        Ok(Self {
1923            gate: ColumnParallelLayer::new(
1924                hidden_size,
1925                intermediate_size,
1926                quantization_config,
1927                false,
1928                comm,
1929                vb.pp("gate_proj"),
1930            )?,
1931            up: ColumnParallelLayer::new(
1932                hidden_size,
1933                intermediate_size,
1934                quantization_config,
1935                false,
1936                comm,
1937                vb.pp("up_proj"),
1938            )?,
1939            down: RowParallelLayer::new(
1940                intermediate_size,
1941                hidden_size,
1942                quantization_config,
1943                false,
1944                comm,
1945                vb.pp("down_proj"),
1946            )?,
1947            act: hidden_act,
1948            params: vec![hidden_size, intermediate_size],
1949        })
1950    }
1951
1952    pub fn replicate(
1953        params: &[usize],
1954        vb: ShardedVarBuilder,
1955        act: Activation,
1956        comm: &Arc<mistralrs_quant::Comm>,
1957    ) -> Result<Self> {
1958        Self::new(vb, params[0], params[1], &None, act, comm)
1959    }
1960
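    /// SwiGLU-style feed-forward: `down(act(gate(x)) * up(x))`, upcasting around the matmuls
    /// to the quantized activation dtype when the quant backend requests one.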
1961    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1962        let original_dtype = xs.dtype();
1963        let mut xs = xs.clone();
1964        if let Some(t) = self.gate.quantized_act_type() {
1965            xs = xs.to_dtype(t)?;
1966        }
1967        let lhs = self.gate.forward(&xs)?;
1968        let rhs = self.up.forward(&xs)?;
1969        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
1970            &lhs,
1971            &rhs,
1972            self.act.try_into()?,
1973        )?)?;
1974        if self.gate.quantized_act_type().is_some() {
1975            res = res.to_dtype(original_dtype)?;
1976        }
1977        Ok(res)
1978    }
1979}
1980
1981impl AnyMoeTrainableLayer for Mlp {}
1982
1983impl MlpLayer for Mlp {
1984    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1985        let original_dtype = xs.dtype();
1986        let mut xs = xs.clone();
1987        if let Some(t) = self.gate.quantized_act_type() {
1988            xs = xs.to_dtype(t)?;
1989        }
1990        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
1991        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
1992        let mut res = if matches!(
1993            self.act,
1994            Activation::Gelu | Activation::Silu | Activation::Relu
1995        ) {
1996            MatMul.qmethod_matmul(
1997                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
1998                &*self.down,
1999            )?
2000        } else {
2001            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2002        };
2003        if self.gate.quantized_act_type().is_some() {
2004            res = res.to_dtype(original_dtype)?;
2005        }
2006        Ok(res)
2007    }
2008    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2009        vec![&mut self.gate, &mut self.up, &mut self.down]
2010    }
2011    fn clone(&self) -> Box<dyn MlpLayer> {
2012        Box::new(Clone::clone(self))
2013    }
2014    fn get_params(&self) -> &[usize] {
2015        &self.params
2016    }
2017    fn hidden_act(&self) -> Activation {
2018        self.act
2019    }
2020    // gate, up, down
2021    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2022        let gate = if let Some(ref delta) = deltas[0] {
2023            self.gate.add_delta_w(delta)?
2024        } else {
2025            self.gate.clone()
2026        };
2027        let up = if let Some(ref delta) = deltas[1] {
2028            self.up.add_delta_w(delta)?
2029        } else {
2030            self.up.clone()
2031        };
2032        let down = if let Some(ref delta) = deltas[2] {
2033            self.down.add_delta_w(delta)?
2034        } else {
2035            self.down.clone()
2036        };
2037
2038        Ok(Box::new(Self {
2039            gate,
2040            up,
2041            down,
2042            act: self.act,
2043            params: self.params.clone(),
2044        }))
2045    }
2046
2047    fn dtype_device(&self) -> (DType, Device) {
2048        self.gate.dtype_and_device()
2049    }
2050}
2051
2052pub struct AvgPool2d {
2053    kernel_size: usize,
2054    stride: usize,
2055}
2056
2057impl AvgPool2d {
2058    pub fn new(kernel_size: usize, stride: usize) -> Self {
2059        Self {
2060            kernel_size,
2061            stride,
2062        }
2063    }
2064
2065    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2066        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2067    }
2068}
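
// Illustrative only: `AvgPool2d` over an NCHW tensor; the 2x2 kernel/stride and the tiny
// input are arbitrary example values.
#[cfg(test)]
mod avg_pool2d_example {
    use super::*;

    #[test]
    fn averages_non_overlapping_windows() -> Result<()> {
        let dev = Device::Cpu;
        // (N, C, H, W) = (1, 1, 2, 2) holding 1..=4, so the single 2x2 window averages to 2.5.
        let xs = Tensor::arange(1f32, 5f32, &dev)?.reshape((1, 1, 2, 2))?;
        let ys = AvgPool2d::new(2, 2).forward(&xs)?;
        assert_eq!(ys.dims(), &[1, 1, 1, 1]);
        assert_eq!(ys.flatten_all()?.to_vec1::<f32>()?, vec![2.5]);
        Ok(())
    }
}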
2069
2070/// Applies 2D reflection padding to a tensor of shape (N, C, H, W).
2071///
2072/// The `padding` argument is a 4-tuple (pad_left, pad_right, pad_top, pad_bottom).
2073/// For left padding, it reflects the values from column 1 up to pad_left (in reverse order);
2074/// for right padding, it reflects from the second-to-last column backwards, and similarly for
2075/// vertical (height) padding.
2076pub struct ReflectionPad2d {
2077    padding: (usize, usize, usize, usize),
2078}
2079
2080impl ReflectionPad2d {
2081    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2082        Self { padding }
2083    }
2084}
2085
2086impl Module for ReflectionPad2d {
2087    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2088        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2089
2090        let (_n, _c, h, w) = xs.dims4()?;
2091
2092        // --- Horizontal Padding (along width, axis = 3) ---
2093        // For left padding, we reflect columns 1..=pad_left (in reverse order).
2094        let left_pad = if pad_left > 0 {
2095            // Create indices: [pad_left, pad_left-1, ..., 1]
2096            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
            // Build the index tensor on the input's device so `index_select` does not mix devices.
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2098        } else {
2099            None
2100        };
2101
2102        // For right padding, we reflect from the right side (excluding the last column).
2103        let right_pad = if pad_right > 0 {
            // Reflect from the second-to-last column: indices [w-2, w-3, ..., w-1-pad_right].
            let start = w as i64 - 2;
            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2108        } else {
2109            None
2110        };
2111
2112        // Concatenate horizontally (along width, dim=3)
2113        let x_padded_width = match (left_pad, right_pad) {
2114            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2115            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2116            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2117            (None, None) => xs.clone(),
2118        };
2119
2120        // --- Vertical Padding (along height, axis = 2) ---
2121        // For top padding, reflect rows 1..=pad_top (in reverse order)
2122        let top_pad = if pad_top > 0 {
2123            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, x_padded_width.device())?, 2)?)
2125        } else {
2126            None
2127        };
2128
2129        // For bottom padding, reflect from the bottom (excluding the last row)
2130        let bottom_pad = if pad_bottom > 0 {
2131            let start = h as i64 - 2;
2132            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, x_padded_width.device())?, 2)?)
2134        } else {
2135            None
2136        };
2137
2138        // Concatenate vertically (along height, dim=2)
2139        let x_padded = match (top_pad, bottom_pad) {
2140            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2141            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2142            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2143            (None, None) => x_padded_width,
2144        };
2145
2146        Ok(x_padded)
2147    }
2148}
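
// Illustrative only: reflection-padding a 3x3 image by one column/row on each side, matching
// the doc comment above; the sizes and values are arbitrary example data.
#[cfg(test)]
mod reflection_pad2d_example {
    use super::*;

    #[test]
    fn pads_by_reflection() -> Result<()> {
        let dev = Device::Cpu;
        // (N, C, H, W) = (1, 1, 3, 3) with rows [0 1 2], [3 4 5], [6 7 8].
        let xs = Tensor::arange(0f32, 9f32, &dev)?.reshape((1, 1, 3, 3))?;
        let ys = ReflectionPad2d::new((1, 1, 1, 1)).forward(&xs)?;
        assert_eq!(ys.dims(), &[1, 1, 5, 5]);
        // The first output row is the width-padded reflection of input row 1: [4, 3, 4, 5, 4].
        let first_row = ys.i((0, 0, 0, ..))?.to_vec1::<f32>()?;
        assert_eq!(first_row, vec![4.0, 3.0, 4.0, 5.0, 4.0]);
        Ok(())
    }
}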
2149
2150pub struct ScaledEmbedding {
2151    scale: f64,
2152    embedding: Embedding,
2153}
2154
2155impl ScaledEmbedding {
2156    pub fn new(scale: f64, embedding: Embedding) -> Self {
2157        Self { scale, embedding }
2158    }
2159
2160    pub fn embeddings(&self) -> &Tensor {
2161        self.embedding.embeddings()
2162    }
2163}
2164
2165impl Module for ScaledEmbedding {
2166    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2167        xs.apply(&self.embedding)? * self.scale
2168    }
2169}
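
// Illustrative only: `ScaledEmbedding` multiplies looked-up embeddings by a constant scale
// (some models use sqrt(hidden_size) here); the table and ids below are arbitrary examples.
#[cfg(test)]
mod scaled_embedding_example {
    use super::*;

    #[test]
    fn scales_embedding_output() -> Result<()> {
        let dev = Device::Cpu;
        // Vocab of 4 tokens, hidden size 2, all-ones weights so the scale is easy to see.
        let table = Tensor::ones((4, 2), DType::F32, &dev)?;
        let emb = ScaledEmbedding::new(3.0, Embedding::new(table, 2));
        let ids = Tensor::new(&[0u32, 2], &dev)?;
        let ys = emb.forward(&ids)?;
        assert_eq!(ys.dims(), &[2, 2]);
        assert_eq!(ys.flatten_all()?.to_vec1::<f32>()?, vec![3.0; 4]);
        Ok(())
    }
}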