mistralrs_core/
layers.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};

use candle_core::{
    quantized::{QMatMul, QTensor},
    Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{
    Conv2d, Conv2dConfig, Embedding, GroupNorm, LayerNorm, LayerNormConfig, Linear, Module,
};
use float8::F8E4M3;
use half::{bf16, f16};
use mistralrs_quant::{
    AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};

pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::repeat_kv;
use crate::{
    amoe::{AnyMoeTrainableLayer, MlpLayer},
    gguf::Content,
    models::llama,
    ops::SplitOp,
    vision_models::{
        gemma3::config::Gemma3TextConfig,
        llama4,
        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
        phi4::Phi4MMConfig,
    },
};

pub use mistralrs_quant::MatMul;

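/// Loads an embedding table with `in_size` rows of width `out_size`. For AFQ-quantized
/// checkpoints the quantized weight is loaded via `AfqLayer` and dequantized into a dense
/// table; otherwise the weight is read directly from the var builder.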
pub fn embedding(
    in_size: usize,
    out_size: usize,
    vb: ShardedVarBuilder,
    config: &Option<QuantizedConfig>,
) -> Result<Embedding> {
    // AFQ-quantized models also quantize the embeddings, so dequantize the weight here.
    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
        let afq_layer =
            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
        afq_layer.dequantize_w()?
    } else {
        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
    };
    Ok(Embedding::new(embeddings, out_size))
}

pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get(size, "weight")?;
    if config.affine {
        let bias = vb.get(size, "bias")?;
        Ok(LayerNorm::new(weight, bias, config.eps))
    } else {
        Ok(LayerNorm::new_no_bias(weight, config.eps))
    }
}

pub fn group_norm(
    num_groups: usize,
    num_channels: usize,
    eps: f64,
    vb: ShardedVarBuilder,
) -> Result<GroupNorm> {
    let weight = vb.get(num_channels, "weight")?;
    let bias = vb.get(num_channels, "bias")?;
    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
}

pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}

pub fn conv2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}

pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    let bs = vb.get(out_dim, "bias")?;
    Ok(Linear::new(ws, Some(bs)))
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    Ok(Linear::new(ws, None))
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: ShardedVarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}

#[derive(Debug, Clone)]
pub struct RmsNorm {
    eps: f64,
    weight: Tensor,
}

impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        Ok(Self { eps, weight: w })
    }

    /// Gemma uses weight + 1.0
    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        let w = (w + 1.0)?;
        Ok(Self { eps, weight: w })
    }

    /// Gemma uses weight + 1.0. Undo for UQFF generation.
    pub fn undo_gemma(&self) -> Result<Self> {
        Ok(Self {
            eps: self.eps,
            weight: (&self.weight - 1.0)?,
        })
    }

    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
        Ok(Self { eps, weight: w })
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

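// RMSNorm computes `y = w * x / sqrt(mean(x^2, last dim) + eps)`. The variant below performs
// the reduction in f32 regardless of the input dtype for numerical stability, then casts back
// before applying the learned scale.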
#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    w: Tensor,
    eps: f64,
}

impl F32RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }

    pub fn weight(&self) -> &Tensor {
        &self.w
    }
}

impl Module for F32RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}

#[derive(Debug, Clone)]
pub struct QRmsNorm {
    eps: f64,
    weight: Tensor,
}

impl QRmsNorm {
    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
        let scale = scale.dequantize(&scale.device())?;
        Ok(Self {
            eps: eps as f64,
            weight: scale,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

/// RoPE supporting LongRope
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

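// LongRoPE keeps two sets of sin/cos tables: the "short" tables are used while the longest
// position in the batch fits within `original_max_position_embeddings`, and the "long" tables
// (built from the long rescale factors) are used beyond that.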
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    #[serde(alias = "yarn")]
    Yarn,
    #[serde(alias = "dynamic")]
    Dynamic,
    #[serde(alias = "linear")]
    Linear,
}

impl FromStr for ScaledRopeType {
    type Err = candle_core::Error;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "su" | "longrope" => Ok(Self::Su),
            "yarn" => Ok(Self::Yarn),
            "linear" => Ok(Self::Linear),
            "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Expected one of `su`/`longrope`, `yarn`, `linear`, or `dynamic` as the scaled RoPE type.".to_string(),
            )),
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
    Classic {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    Scaled {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
    },
}

pub struct PhiRopeConfig {
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub original_max_position_embeddings: usize,
    pub rope_theta: f64,
    pub head_dim: usize,
    pub partial_rotary_factor: Option<f64>,
}

impl PhiRotaryEmbedding {
    fn new_classic_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        // Calculate scale
        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            match scaling_type {
                ScaledRopeType::Su => {
                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
                }
                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
            }
        };

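        // The scale (context extension ratio) determines an amplitude correction applied to
        // sin/cos: `sqrt(1 + ln(scale) / ln(original_max))` for `su`/LongRoPE and
        // `0.1 * ln(scale) + 1` for YaRN, as computed above.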
        // Calculate inv freqs for short, long
        let inv_freq_long = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_short = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq_long.len();

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;

        // Calculate sin,cos for long
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
        let freqs_long = t.matmul(&inv_freq_long)?;
        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        // Calculate sin,cos for short
        let inv_freq_short =
            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
        let freqs_short = t.matmul(&inv_freq_short)?;
        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        Ok(Self {
            short_cos,
            short_sin,
            long_cos: Some(long_cos),
            long_sin: Some(long_sin),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

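    // Unscaled RoPE: inverse frequencies follow the standard schedule
    // `inv_freq[j] = 1 / theta^(2j / dim)`, and sin/cos are tabulated for every position up to
    // `max_position_embeddings`.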
    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        if !matches!(scaling_type, ScaledRopeType::Su) {
            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
        }

        if short_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
                short_factor.len(),
                dim / 2
            );
        }
        if long_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
                long_factor.len(),
                dim / 2
            );
        }

        // Short cos/sin
        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;

        // Long cos/sin
        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
        let cfg: PhiRopeConfig = cfg.into();

        match &cfg.rope_scaling {
            Some(PhiRopeScalingConfig::Classic {
                short_factor,
                long_factor,
                scaling_type,
            }) => {
                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
            }

            Some(PhiRopeScalingConfig::Scaled {
                short_factor,
                long_factor,
                scaling_type,
                long_mscale,
                short_mscale,
            }) => Self::new_scaled(
                short_factor,
                long_factor,
                scaling_type,
                *long_mscale,
                *short_mscale,
                &cfg,
                dtype,
                dev,
            ),

            None => Self::new_unscaled(&cfg, dtype, dev),
        }
    }

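    // Usage sketch (assumed names; Phi-3-style attention): build once per model, then apply
    // per attention layer before SDPA:
    //   let rope = PhiRotaryEmbedding::new(dtype, cfg.clone(), device)?;
    //   let (q, k) = rope.forward(&q, &k, &seqlen_offsets, &position_ids)?;
    // where `q`/`k` are (batch, heads, seq, head_dim) and `seqlen_offsets` holds the KV-cache
    // offset of each sequence in the batch.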
    /// Returns (sin, cos) taking into account LongRope
    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        let rot_dim = cos.dim(D::Minus1)? * 2;

        // Case for Phi 3 / Phi 4 mini: partial rotary, only the first `rot_dim` features rotate
        if rot_dim != q.dim(D::Minus1)? {
            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
                (q_embed, k_embed)
            } else {
                let mut q_embeds = Vec::new();
                let mut k_embeds = Vec::new();
                for (i, offset) in seqlen_offsets.iter().enumerate() {
                    let cos = cos.narrow(0, *offset, seq_len)?;
                    let sin = sin.narrow(0, *offset, seq_len)?;
                    let q_embed = candle_nn::rotary_emb::rope(
                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    let k_embed = candle_nn::rotary_emb::rope(
                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    q_embeds.push(q_embed);
                    k_embeds.push(k_embed);
                }
                let q_rot = Tensor::cat(&q_embeds, 0)?;
                let k_rot = Tensor::cat(&k_embeds, 0)?;
                (q_rot, k_rot)
            };

            Ok((
                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
            ))
        } else if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

/// RoPE for Llama3
#[derive(Debug, Clone)]
pub struct Llama3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: f32,
    pub high_freq_factor: f32,
    pub original_max_position_embeddings: usize,
    pub rope_type: Llama3RopeType,
}

fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
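// Llama 3 RoPE scaling works on wavelengths: frequencies whose wavelength is shorter than
// `original_max_position_embeddings / high_freq_factor` are left untouched, those longer than
// `original_max_position_embeddings / low_freq_factor` are divided by `factor`, and the band
// in between is smoothly interpolated between the two.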
impl Llama3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &llama::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(rope_scaling) => {
                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.low_freq_factor;
                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.high_freq_factor;

                let inv_freq = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / rope_scaling.factor
                        } else {
                            let smooth = (rope_scaling.original_max_position_embeddings as f32
                                / wavelen
                                - rope_scaling.low_freq_factor)
                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_llama4(
        dtype: DType,
        cfg: &llama4::TextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(rope_scaling) => {
                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.low_freq_factor;
                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.high_freq_factor;

                let inv_freq = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / rope_scaling.factor
                        } else {
                            let smooth = (rope_scaling.original_max_position_embeddings as f32
                                / wavelen
                                - rope_scaling.low_freq_factor)
                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_mllama3(
        dtype: DType,
        cfg: &MLlamaTextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Llama3,
                original_max_position_embeddings,
                factor,
                attention_factor: _,
                beta_fast: _,
                beta_slow: _,
                short_factor: _,
                long_factor: _,
                low_freq_factor,
                high_freq_factor,
            }) => {
                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
                let low_freq_factor = low_freq_factor
                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
                let high_freq_factor = high_freq_factor
                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;

                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;

                let inv_freq = (0..head_dim)
                    .step_by(2)
                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / factor
                        } else {
                            let smooth = (*original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(MLlamaRopeScaling {
                rope_type: other, ..
            }) => {
                candle_core::bail!(
                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
                )
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
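// Qwen2-VL uses multimodal RoPE (M-RoPE): position ids carry three streams (temporal, height,
// width). `mrope_section` splits the rotary dimensions into chunks, and chunk `i` takes its
// angles from stream `i % 3` before the usual rope rotation is applied.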
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    /// (cos, sin)
    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

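// Usage sketch (assumed shapes): `position_ids` is (3, batch, seq) with one row per M-RoPE
// stream. The cos/sin tables are computed once per forward pass and reused by every layer:
//   let (cos, sin) = rope.compute_cos_sin(&position_ids, dtype)?;
//   rope.forward(&(cos, sin), &mut q, &mut k)?;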
#[derive(Debug, Clone)]
pub struct Qwen2_5VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2_5VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    /// (cos, sin)
    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}

pub struct DeepSeekV2RopeConfig {
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub max_position_embeddings: usize,
    pub rope_theta: f32,
    pub qk_rope_head_dim: usize,
}

impl DeepSeekV2RotaryEmbedding {
    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.qk_rope_head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    fn yarn_find_correction_dim(
        num_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> f32 {
        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
            / (2. * base.ln())
    }

    fn yarn_find_correction_range(
        low_rot: f32,
        high_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> (f32, f32) {
        let low =
            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
        let high =
            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
        (low.max(0.), high.min(dim as f32 - 1.))
    }

    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
        if min == max {
            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
            max += 0.001;
        }
        let linear_func =
            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
        linear_func.clamp(0., 1.)
    }

    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
        if scale <= 1. {
            return 1.;
        }
        0.1 * mscale * scale.ln() + 1.
    }

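    // YaRN blends two frequency schedules: "extrapolated" frequencies (the original RoPE ones)
    // and "interpolated" frequencies (divided by `factor`), mixed with a linear ramp between
    // the correction dims found above. The sin/cos tables are then scaled by
    // `yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim)`.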
    #[allow(clippy::too_many_arguments)]
    fn new_yarn(
        cfg: &DeepSeekV2RopeConfig,
        dtype: DType,
        dev: &Device,
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        factor: f32,
        mscale: f32,
        mscale_all_dim: f32,
    ) -> Result<Self> {
        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
            .collect();
        let freq_extra_len = freq_extra.len();
        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
            .collect();
        let freq_inter_len = freq_inter.len();
        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;

        let (low, high) = Self::yarn_find_correction_range(
            beta_fast,
            beta_slow,
            cfg.qk_rope_head_dim,
            cfg.rope_theta,
            original_max_position_embeddings,
        );
        let inv_freq_mask =
            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
        let inv_freq = freq_inter
            .broadcast_mul(&(1. - &inv_freq_mask)?)?
            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;

        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((cfg.max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let mscale =
            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
                scaling_type: _,
                factor: _,
            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
            Some(DeepSeekV2RopeScaling::Yarn {
                original_max_position_embeddings,
                beta_fast,
                beta_slow,
                factor,
                mscale,
                mscale_all_dim,
                scaling_type: _,
            }) => Self::new_yarn(
                cfg,
                dtype,
                dev,
                *original_max_position_embeddings,
                *beta_fast,
                *beta_slow,
                *factor,
                *mscale,
                *mscale_all_dim,
            ),
            None => Self::new_unscaled(cfg, dtype, dev),
        }
    }

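    // Note: the forward below uses `rope_i`, the interleaved rotary variant, rather than the
    // half-split `rope` used elsewhere in this file.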
    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope_i(
                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope_i(
                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

#[derive(Debug, Clone)]
pub struct Phi4MMRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Phi4MMScaledRopeType {
    #[serde(alias = "longrope")]
    LongRope,
    #[default]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Phi4MMRopeScalingConfig {
    short_factor: Option<Vec<f64>>,
    long_factor: Option<Vec<f64>>,
    #[serde(rename = "type")]
    scaling_type: Phi4MMScaledRopeType,
}

impl Phi4MMRotaryEmbedding {
    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_longrope(
        short_factor: &[f64],
        long_factor: &[f64],
        cfg: &Phi4MMConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        // Calculate scale
        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
        };

        // Short cos/sin
        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;

        // Long cos/sin
        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;

        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Phi4MMRopeScalingConfig {
                scaling_type: Phi4MMScaledRopeType::LongRope,
                short_factor: Some(short_factor),
                long_factor: Some(long_factor),
            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),

            _ => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    /// Returns (sin, cos) taking into account LongRope
    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);

        let rot_dim = cos.dim(D::Minus1)? * 2;
        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
            (q_embed, k_embed)
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(
                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope(
                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            let q_rot = Tensor::cat(&q_embeds, 0)?;
            let k_rot = Tensor::cat(&k_embeds, 0)?;
            (q_rot, k_rot)
        };

        Ok((
            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
        ))
    }
}

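// Gemma 3 only supports linear RoPE scaling: the inverse frequencies are simply divided by
// `factor` (a factor of 1.0 reproduces unscaled RoPE).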
#[derive(Debug, Clone)]
pub struct Gemma3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3ScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3RopeScalingConfig {
    factor: f64,
    rope_type: Gemma3ScaledRopeType,
}

impl Gemma3RotaryEmbedding {
    fn new_linear(
        cfg: &Gemma3TextConfig,
        factor: f64,
        is_gpt_neox: bool,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let inv_freq = (inv_freq / factor)?;

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self(RotaryEmbedding {
            cos,
            sin,
            is_gpt_neox,
        }))
    }

    pub fn new(
        is_gpt_neox: bool,
        dtype: DType,
        cfg: &Gemma3TextConfig,
        dev: &Device,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Gemma3RopeScalingConfig {
                rope_type: Gemma3ScaledRopeType::Linear,
                factor,
            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),

            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

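// Dia's rotary embedding parameterizes frequencies as timescales between `min_timescale` and
// `max_timescale`: `timescale[i] = min * (max / min)^(2i / head_dim)`. The rotation is applied
// explicitly to the two halves of the feature dimension rather than via the fused rope kernel.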
1491pub struct DiaRotaryEmbedding {
1492    timescale: Tensor,
1493    dtype: DType,
1494}
1495
1496impl DiaRotaryEmbedding {
1497    pub fn new(
1498        min_timescale: f32,
1499        max_timescale: f32,
1500        head_dim: usize,
1501        device: &Device,
1502        dtype: DType,
1503    ) -> Result<Self> {
1504        assert_eq!(head_dim % 2, 0);
1505        let half_embedding_dim = head_dim / 2;
1506
1507        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
1508        let timescale = fraction
1509            .into_iter()
1510            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
1511            .collect::<Vec<_>>();
1512
1513        let timescale_len = timescale.len();
1514        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
1515
1516        Ok(Self { timescale, dtype })
1517    }
1518
1519    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
1520        let freqs = positions
1521            .unsqueeze(D::Minus1)?
1522            .unsqueeze(D::Minus1)?
1523            .broadcast_div(&self.timescale)?;
1524
1525        let sin = freqs.sin()?.to_dtype(self.dtype)?;
1526        let cos = freqs.cos()?.to_dtype(self.dtype)?;
1527
1528        let split = xs.chunk(2, D::Minus1)?;
1529        let first_half = &split[0];
1530        let second_half = &split[1];
1531
1532        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
1533        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
1534
1535        Tensor::cat(&[first_part, second_part], D::Minus1)
1536    }
1537}
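
// A minimal usage sketch for `DiaRotaryEmbedding`, compiled only under `cfg(test)`.
// The shapes used here ((batch, seq, heads, head_dim) activations with float
// positions of shape (batch, seq)) are illustrative assumptions chosen to
// exercise the broadcasting above, not a documented contract.
#[cfg(test)]
mod dia_rotary_embedding_example {
    use super::DiaRotaryEmbedding;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn applies_rope_without_changing_shape() -> Result<()> {
        let dev = Device::Cpu;
        let rope = DiaRotaryEmbedding::new(1.0, 10_000.0, 8, &dev, DType::F32)?;
        // Activations: (batch, seq, heads, head_dim); positions: (batch, seq) as floats.
        let xs = Tensor::randn(0f32, 1f32, (1, 3, 2, 8), &dev)?;
        let positions = Tensor::arange(0u32, 3u32, &dev)?
            .to_dtype(DType::F32)?
            .unsqueeze(0)?;
        let out = rope.forward(&xs, &positions)?;
        assert_eq!(out.dims(), xs.dims());
        Ok(())
    }
}
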
1538#[derive(Debug, Clone)]
1539pub struct QLinear {
1540    inner: QMatMul,
1541    bias: Option<Tensor>,
1542    dtype: DType,
1543}
1544
1545impl QLinear {
1546    pub fn new<R: std::io::Read + std::io::Seek>(
1547        ct: &mut Content<'_, R>,
1548        name: &str,
1549        device: &Device,
1550    ) -> Result<Self> {
1551        let w = ct.tensor(&format!("{name}.weight"), device)?;
1552        let b = ct.tensor(&format!("{name}.bias"), device)?;
1553        let inner = QMatMul::from_qtensor(w)?;
1554        let bias = b.dequantize(device)?;
1555        Ok(Self {
1556            inner,
1557            bias: Some(bias),
1558            dtype: DType::F32,
1559        })
1560    }
1561
1562    pub fn from_linear(linear: Linear) -> Self {
1563        Self {
1564            inner: QMatMul::Tensor(linear.weight().clone()),
1565            bias: linear.bias().cloned(),
1566            dtype: linear.weight().dtype(),
1567        }
1568    }
1569
1570    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1571        let dtype = w.dtype();
1572        Self {
1573            inner: QMatMul::Tensor(w),
1574            bias: b,
1575            dtype,
1576        }
1577    }
1578
1579    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1580        if let Some(ref b) = b {
1581            assert_eq!(b.dtype(), DType::F32);
1582        }
1583        Self {
1584            inner: QMatMul::QTensor(Arc::new(w)),
1585            bias: b,
1586            dtype: DType::F32,
1587        }
1588    }
1589
1590    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1591        Self {
1592            inner,
1593            bias: old.bias.clone(),
1594            dtype: old.dtype,
1595        }
1596    }
1597
1598    pub fn inner(&mut self) -> &mut QMatMul {
1599        &mut self.inner
1600    }
1601
1602    pub fn inner_ref(&self) -> &QMatMul {
1603        &self.inner
1604    }
1605
1606    pub fn is_quant(&self) -> bool {
1607        matches!(self.inner, QMatMul::QTensor(_))
1608    }
1609
1610    pub fn bias(&self) -> Option<&Tensor> {
1611        self.bias.as_ref()
1612    }
1613
1614    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1615        self.bias.as_mut()
1616    }
1617}
1618
1619impl Module for QLinear {
1620    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1621        let xs = if self.is_quant() {
1622            xs.to_dtype(DType::F32)?
1623        } else {
1624            xs.clone()
1625        };
1626        if let Some(bias) = &self.bias {
1627            self.inner
1628                .forward(&xs)?
1629                .broadcast_add(bias)?
1630                .to_dtype(self.dtype)
1631        } else {
1632            self.inner.forward(&xs)?.to_dtype(self.dtype)
1633        }
1634    }
1635}
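
// A minimal sketch of driving `QLinear` through the `Module` impl above with
// unquantized parts; the weight/bias shapes and random values are illustrative
// assumptions, not taken from any particular model.
#[cfg(test)]
mod qlinear_example {
    use super::QLinear;
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::Module;

    #[test]
    fn from_parts_applies_weight_and_bias() -> Result<()> {
        let dev = Device::Cpu;
        // Weight is (out_features = 3, in_features = 4); bias is zero.
        let w = Tensor::randn(0f32, 1f32, (3, 4), &dev)?;
        let b = Tensor::zeros(3, DType::F32, &dev)?;
        let layer = QLinear::from_parts(w, Some(b));
        assert!(!layer.is_quant());

        let xs = Tensor::randn(0f32, 1f32, (2, 4), &dev)?;
        let ys = layer.forward(&xs)?;
        assert_eq!(ys.dims(), &[2, 3]);
        Ok(())
    }
}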
1636
1637#[derive(Debug, Clone)]
1638pub struct RotaryEmbedding {
1639    cos: Tensor,
1640    sin: Tensor,
1641    is_gpt_neox: bool,
1642}
1643
1644impl RotaryEmbedding {
1645    pub fn new(
1646        base: f32,
1647        head_dim: usize,
1648        max_position_embeddings: usize,
1649        device: &Device,
1650        is_gpt_neox: bool,
1651        dtype: DType,
1652    ) -> Result<Self> {
1653        let inv_freq: Vec<_> = (0..head_dim)
1654            .step_by(2)
1655            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1656            .collect();
1657        let inv_freq_len = inv_freq.len();
1658        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1659        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1660            .to_dtype(DType::F32)?
1661            .reshape((max_position_embeddings, 1))?;
1662        let freqs = t.matmul(&inv_freq)?;
1663        let sin = freqs.sin()?.to_dtype(dtype)?;
1664        let cos = freqs.cos()?.to_dtype(dtype)?;
1665
1666        Ok(Self {
1667            cos,
1668            sin,
1669            is_gpt_neox,
1670        })
1671    }
1672
1673    pub fn new_partial(
1674        base: f32,
1675        rot_dim: usize,
1676        max_position_embeddings: usize,
1677        device: &Device,
1678        is_gpt_neox: bool,
1679        dtype: DType,
1680    ) -> Result<Self> {
1681        let inv_freq: Vec<_> = (0..rot_dim)
1682            .step_by(2)
1683            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
1684            .collect();
1685        let inv_freq_len = inv_freq.len();
1686        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1687        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1688            .to_dtype(DType::F32)?
1689            .reshape((max_position_embeddings, 1))?;
1690        let freqs = t.matmul(&inv_freq)?;
1691        let sin = freqs.sin()?.to_dtype(dtype)?;
1692        let cos = freqs.cos()?.to_dtype(dtype)?;
1693
1694        Ok(Self {
1695            cos,
1696            sin,
1697            is_gpt_neox,
1698        })
1699    }
1700
1701    pub fn forward(
1702        &self,
1703        q: &Tensor,
1704        k: &Tensor,
1705        seqlen_offsets: &[usize],
1706    ) -> Result<(Tensor, Tensor)> {
1707        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
1708        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
1709
1710        let rope = if self.is_gpt_neox {
1711            candle_nn::rotary_emb::rope
1712        } else {
1713            candle_nn::rotary_emb::rope_i
1714        };
1715
1716        if cfg!(feature = "cuda") && qh == kh {
1717            let (cos, sin) = if seqlen_offsets.len() == 1 {
1718                (
1719                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
1720                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
1721                )
1722            } else {
1723                let mut cos_s = Vec::new();
1724                let mut sin_s = Vec::new();
1725                for offset in seqlen_offsets {
1726                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
1727                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
1728                }
1729                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
1730            };
1731
1732            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
1733            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
1734            mistralrs_quant::rotary::apply_rotary_inplace(
1735                &q_embed,
1736                &k_embed,
1737                &cos,
1738                &sin,
1739                self.is_gpt_neox,
1740            )?;
1741            let mut q = q_embed
1742                .reshape((b_sz, seq_len, qh, n_embd))?
1743                .transpose(1, 2)?;
1744            let mut k = k_embed
1745                .reshape((b_sz, seq_len, kh, n_embd))?
1746                .transpose(1, 2)?;
1747            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
1748                q = q.contiguous()?;
1749                k = k.contiguous()?;
1750            }
1751            Ok((q, k))
1752        } else if seqlen_offsets.len() == 1 {
1753            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1754            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1755            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1756            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
1757            Ok((q_embed, k_embed))
1758        } else {
1759            let mut q_embeds = Vec::new();
1760            let mut k_embeds = Vec::new();
1761            for (i, offset) in seqlen_offsets.iter().enumerate() {
1762                let cos = self.cos.narrow(0, *offset, seq_len)?;
1763                let sin = self.sin.narrow(0, *offset, seq_len)?;
1764                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1765                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1766                q_embeds.push(q_embed);
1767                k_embeds.push(k_embed);
1768            }
1769            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1770        }
1771    }
1772}
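
// A minimal sketch of the plain `RotaryEmbedding` above. It assumes a CPU-only
// build: with the `cuda` feature enabled, the fast path expects device tensors
// and `mistralrs_quant::rotary::apply_rotary_inplace`. The (batch, heads, seq,
// head_dim) shapes and the single zero offset are illustrative.
#[cfg(test)]
mod rotary_embedding_example {
    use super::RotaryEmbedding;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn preserves_q_and_k_shapes() -> Result<()> {
        let dev = Device::Cpu;
        let rope = RotaryEmbedding::new(10_000.0, 8, 32, &dev, true, DType::F32)?;
        // q/k layout: (batch, heads, seq, head_dim); one sequence starting at offset 0.
        let q = Tensor::randn(0f32, 1f32, (1, 2, 3, 8), &dev)?;
        let k = Tensor::randn(0f32, 1f32, (1, 2, 3, 8), &dev)?;
        let (q_rot, k_rot) = rope.forward(&q, &k, &[0])?;
        assert_eq!(q_rot.dims(), q.dims());
        assert_eq!(k_rot.dims(), k.dims());
        Ok(())
    }
}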
1773
1774#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
1775#[serde(rename_all = "lowercase")]
1776pub enum Activation {
1777    #[default]
1778    #[serde(alias = "gelu")]
1779    Gelu,
1780    #[serde(alias = "gelu_new")]
1781    NewGelu,
1782    Relu,
1783    Relu2,
1784    Relu6,
1785    Silu,
1786    Sigmoid,
1787    HardSigmoid,
1788    Swiglu,
1789    Swish,
1790    HardSwish,
1791    Elu(f64),
1792    LeakyRelu(f64),
1793    #[serde(alias = "gelu_pytorch_tanh")]
1794    GeluPytorchTanh,
1795    QuickGelu,
1796}
1797
1798impl Module for Activation {
1799    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1800        match self {
1801            Self::Gelu => xs.gelu_erf(),
1802            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
1803            Self::NewGelu => xs.gelu(),
1804            Self::Relu => xs.relu(),
1805            Self::Relu2 => xs.relu()?.sqr(),
1806            Self::Relu6 => xs.clamp(0f32, 6f32),
1807            Self::Silu => xs.silu(),
1808            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
1809            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
1810            Self::Swiglu => candle_nn::ops::swiglu(xs),
1811            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
1812            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
1813            &Self::Elu(alpha) => xs.elu(alpha),
1814            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
1815            Self::GeluPytorchTanh => xs.gelu(),
1816            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
1817        }
1818    }
1819}
1820
1821impl TryInto<candle_nn::Activation> for Activation {
1822    type Error = candle_core::Error;
1823
1824    fn try_into(self) -> Result<candle_nn::Activation> {
1825        match self {
1826            Self::Gelu => Ok(candle_nn::Activation::Gelu),
1827            Self::Relu => Ok(candle_nn::Activation::Relu),
1828            Self::Silu => Ok(candle_nn::Activation::Silu),
1829            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
1830            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
1831            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
1832            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
1833            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
1834            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
1835            Self::Swish => Ok(candle_nn::Activation::Swish),
1836            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
1837            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
1838            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
1839            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
1840            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
1841        }
1842    }
1843}
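
// A small sketch contrasting the two ways of applying an `Activation`: every
// variant can go through `Module::forward`, while the `TryInto` conversion
// above (used by fused paths such as `mul_and_act` further down) deliberately
// fails for `QuickGelu`. Values are illustrative.
#[cfg(test)]
mod activation_example {
    use super::Activation;
    use candle_core::{Device, Result, Tensor};
    use candle_nn::Module;

    #[test]
    fn quick_gelu_forwards_but_has_no_candle_nn_mapping() -> Result<()> {
        let xs = Tensor::new(&[-1f32, 0., 1.], &Device::Cpu)?;
        let _ = Activation::QuickGelu.forward(&xs)?;
        assert!(TryInto::<candle_nn::Activation>::try_into(Activation::QuickGelu).is_err());
        assert!(TryInto::<candle_nn::Activation>::try_into(Activation::Silu).is_ok());
        Ok(())
    }
}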
1844
1845#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1846pub struct Conv3dConfig {
1847    pub padding: usize,
1848    pub stride: usize,
1849    pub dilation: usize,
1850    pub groups: usize,
1851}
1852
1853impl Default for Conv3dConfig {
1854    fn default() -> Self {
1855        Self {
1856            padding: 0,
1857            stride: 1,
1858            dilation: 1,
1859            groups: 1,
1860        }
1861    }
1862}
1863
1864pub struct Conv3dNoBias {
1865    conv2d_1: Conv2d,
1866    conv2d_2: Conv2d,
1867}
1868
1869impl Conv3dNoBias {
1870    pub fn new(
1871        in_channels: usize,
1872        out_channels: usize,
1873        kernel_sizes: [usize; 3],
1874        cfg: Conv3dConfig,
1875        vb: ShardedVarBuilder,
1876    ) -> Result<Self> {
1877        let ws = vb.get(
1878            (
1879                out_channels,
1880                in_channels / cfg.groups,
1881                kernel_sizes[0],
1882                kernel_sizes[1],
1883                kernel_sizes[2],
1884            ),
1885            "weight",
1886        )?;
1887
1888        // Split on temporal dimension
1889        // https://github.com/pytorch/pytorch/issues/139066
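        // This split assumes a temporal kernel size of 2: only kernel taps 0 and 1 are
        // extracted here, and `forward` below reads only input frames 0 and 1.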
1890
1891        let w1 = ws.i((.., .., 0, .., ..))?;
1892        let w2 = ws.i((.., .., 1, .., ..))?;
1893
1894        let cfg = Conv2dConfig {
1895            padding: cfg.padding,
1896            stride: cfg.stride,
1897            dilation: cfg.dilation,
1898            groups: cfg.groups,
1899        };
1900
1901        Ok(Self {
1902            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
1903            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
1904        })
1905    }
1906}
1907
1908impl Module for Conv3dNoBias {
1909    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1910        let xs1 = xs.i((.., .., 0, .., ..))?;
1911        let xs2 = xs.i((.., .., 1, .., ..))?;
1912
1913        (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
1914    }
1915}
1916
1917pub trait TensorInfExtend {
1918    fn is_inf(&self) -> Result<Self>
1919    where
1920        Self: Sized;
1921    fn any(&self) -> Result<bool>;
1922}
1923
1924impl TensorInfExtend for Tensor {
1925    fn is_inf(&self) -> Result<Self> {
1926        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
1927    }
1928
1929    fn any(&self) -> Result<bool> {
1930        let sum = self.sum_all()?;
1931        match self.dtype() {
1932            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
1933            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
1934            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
1935            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
1936            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
1937            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
1938            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
1939            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
1940            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
1941            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
1942        }
1943    }
1944}
1945
1946pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
1947    let mut max = match xs.dtype() {
1948        DType::U8 => u8::MAX as f32 - 1000.,
1949        DType::U32 => u32::MAX as f32 - 1000.,
1950        DType::I16 => i16::MAX as f32 - 1000.,
1951        DType::I32 => i32::MAX as f32 - 1000.,
1952        DType::I64 => i64::MAX as f32 - 1000.,
1953        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
1954        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
1955        DType::F32 => f32::MAX - 1000.,
1956        DType::F64 => f64::MAX as f32 - 1000.,
1957        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
1958    };
1959    if xs.is_inf()?.any()? {
1960        max -= 1000.;
1961    }
1962    xs.clamp(-max, max)
1963}
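
// A minimal sketch of `clamp_for_f16`: infinities are replaced by large finite
// values near the dtype maximum. The input values are illustrative.
#[cfg(test)]
mod clamp_for_f16_example {
    use super::clamp_for_f16;
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn replaces_infinities_with_finite_values() -> Result<()> {
        let xs = Tensor::new(&[1f32, f32::INFINITY, f32::NEG_INFINITY], &Device::Cpu)?;
        let clamped = clamp_for_f16(&xs)?;
        assert!(clamped.to_vec1::<f32>()?.iter().all(|v| v.is_finite()));
        Ok(())
    }
}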
1964
1965pub struct FloatInfo {
1966    /// Minimum representable value.
1967    pub min: f64,
1968    /// Maximum representable value.
1969    pub max: f64,
1970    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
1971    pub eps: f64,
1972    pub dtype: DType,
1973}
1974
1975pub trait GetFloatInfo {
1976    fn finfo(&self) -> Result<FloatInfo>;
1977}
1978
1979impl GetFloatInfo for DType {
1980    fn finfo(&self) -> Result<FloatInfo> {
1981        let finfo = match self {
1982            Self::BF16 => FloatInfo {
1983                min: bf16::MIN.to_f64(),
1984                max: bf16::MAX.to_f64(),
1985                eps: bf16::EPSILON.to_f64(),
1986                dtype: DType::BF16,
1987            },
1988            Self::F16 => FloatInfo {
1989                min: f16::MIN.to_f64(),
1990                max: f16::MAX.to_f64(),
1991                eps: f16::EPSILON.to_f64(),
1992                dtype: DType::F16,
1993            },
1994            Self::F32 => FloatInfo {
1995                min: f32::MIN as f64,
1996                max: f32::MAX as f64,
1997                eps: f32::EPSILON as f64,
1998                dtype: DType::F32,
1999            },
2000            Self::F64 => FloatInfo {
2001                min: f64::MIN,
2002                max: f64::MAX,
2003                eps: f64::EPSILON,
2004                dtype: DType::F64,
2005            },
2006            Self::F8E4M3 => FloatInfo {
2007                min: F8E4M3::MIN.to_f64(),
2008                max: F8E4M3::MAX.to_f64(),
2009                eps: F8E4M3::EPSILON.to_f64(),
2010                dtype: DType::F8E4M3,
2011            },
2012            other => {
2013                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
2014            }
2015        };
2016        Ok(finfo)
2017    }
2018}
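
// A small sketch of `GetFloatInfo`, e.g. for fetching the most negative
// representable value of an activation dtype to use as a masking fill. The
// assertions follow directly from the table above.
#[cfg(test)]
mod finfo_example {
    use super::GetFloatInfo;
    use candle_core::{DType, Result};

    #[test]
    fn finfo_reports_dtype_limits() -> Result<()> {
        let finfo = DType::F16.finfo()?;
        assert_eq!(finfo.min, half::f16::MIN.to_f64());
        assert_eq!(finfo.max, half::f16::MAX.to_f64());
        // Integer dtypes are rejected.
        assert!(DType::U8.finfo().is_err());
        Ok(())
    }
}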
2019
2020#[derive(Clone)]
2021pub struct Mlp {
2022    pub gate: Arc<dyn QuantMethod>,
2023    pub up: Arc<dyn QuantMethod>,
2024    pub down: Arc<dyn QuantMethod>,
2025    act: Activation,
2026    params: Vec<usize>,
2027}
2028
2029impl Mlp {
2030    pub fn new(
2031        vb: ShardedVarBuilder,
2032        hidden_size: usize,
2033        intermediate_size: usize,
2034        quantization_config: &Option<QuantizedConfig>,
2035        hidden_act: Activation,
2036        comm: &Arc<mistralrs_quant::Comm>,
2037    ) -> Result<Self> {
2038        Ok(Self {
2039            gate: ColumnParallelLayer::new(
2040                hidden_size,
2041                intermediate_size,
2042                quantization_config,
2043                false,
2044                comm,
2045                vb.pp("gate_proj"),
2046            )?,
2047            up: ColumnParallelLayer::new(
2048                hidden_size,
2049                intermediate_size,
2050                quantization_config,
2051                false,
2052                comm,
2053                vb.pp("up_proj"),
2054            )?,
2055            down: RowParallelLayer::new(
2056                intermediate_size,
2057                hidden_size,
2058                quantization_config,
2059                false,
2060                comm,
2061                vb.pp("down_proj"),
2062            )?,
2063            act: hidden_act,
2064            params: vec![hidden_size, intermediate_size],
2065        })
2066    }
2067
2068    pub fn replicate(
2069        params: &[usize],
2070        vb: ShardedVarBuilder,
2071        act: Activation,
2072        comm: &Arc<mistralrs_quant::Comm>,
2073    ) -> Result<Self> {
2074        Self::new(vb, params[0], params[1], &None, act, comm)
2075    }
2076
2077    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2078        let original_dtype = xs.dtype();
2079        let mut xs = xs.clone();
2080        if let Some(t) = self.gate.quantized_act_type() {
2081            xs = xs.to_dtype(t)?;
2082        }
2083        let lhs = self.gate.forward(&xs)?;
2084        let rhs = self.up.forward(&xs)?;
2085        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
2086            &lhs,
2087            &rhs,
2088            self.act.try_into()?,
2089        )?)?;
2090        if self.gate.quantized_act_type().is_some() {
2091            res = res.to_dtype(original_dtype)?;
2092        }
2093        Ok(res)
2094    }
2095}
2096
2097impl AnyMoeTrainableLayer for Mlp {}
2098
2099impl MlpLayer for Mlp {
2100    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2101        let original_dtype = xs.dtype();
2102        let mut xs = xs.clone();
2103        if let Some(t) = self.gate.quantized_act_type() {
2104            xs = xs.to_dtype(t)?;
2105        }
2106        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2107        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
2108        let mut res = if matches!(
2109            self.act,
2110            Activation::Gelu | Activation::Silu | Activation::Relu
2111        ) {
2112            MatMul.qmethod_matmul(
2113                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
2114                &*self.down,
2115            )?
2116        } else {
2117            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2118        };
2119        if self.gate.quantized_act_type().is_some() {
2120            res = res.to_dtype(original_dtype)?;
2121        }
2122        Ok(res)
2123    }
2124    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2125        vec![&mut self.gate, &mut self.up, &mut self.down]
2126    }
2127    fn clone(&self) -> Box<dyn MlpLayer> {
2128        Box::new(Clone::clone(self))
2129    }
2130    fn get_params(&self) -> &[usize] {
2131        &self.params
2132    }
2133    fn hidden_act(&self) -> Activation {
2134        self.act
2135    }
2136    // gate, up, down
2137    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2138        let gate = if let Some(ref delta) = deltas[0] {
2139            self.gate.add_delta_w(delta)?
2140        } else {
2141            self.gate.clone()
2142        };
2143        let up = if let Some(ref delta) = deltas[1] {
2144            self.up.add_delta_w(delta)?
2145        } else {
2146            self.up.clone()
2147        };
2148        let down = if let Some(ref delta) = deltas[2] {
2149            self.down.add_delta_w(delta)?
2150        } else {
2151            self.down.clone()
2152        };
2153
2154        Ok(Box::new(Self {
2155            gate,
2156            up,
2157            down,
2158            act: self.act,
2159            params: self.params.clone(),
2160        }))
2161    }
2162
2163    fn dtype_device(&self) -> (DType, Device) {
2164        self.gate.dtype_and_device()
2165    }
2166}
2167
2168pub struct AvgPool2d {
2169    kernel_size: usize,
2170    stride: usize,
2171}
2172
2173impl AvgPool2d {
2174    pub fn new(kernel_size: usize, stride: usize) -> Self {
2175        Self {
2176            kernel_size,
2177            stride,
2178        }
2179    }
2180
2181    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2182        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2183    }
2184}
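
// A minimal sketch of `AvgPool2d`: with kernel 2 and stride 2 the spatial
// dimensions are halved. The input shape is illustrative.
#[cfg(test)]
mod avg_pool2d_example {
    use super::AvgPool2d;
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn halves_spatial_dims_with_kernel_2_stride_2() -> Result<()> {
        let xs = Tensor::ones((1, 3, 8, 8), DType::F32, &Device::Cpu)?;
        let pool = AvgPool2d::new(2, 2);
        let ys = pool.forward(&xs)?;
        assert_eq!(ys.dims(), &[1, 3, 4, 4]);
        Ok(())
    }
}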
2185
2186/// Applies 2D reflection padding to a tensor of shape (N, C, H, W).
2187///
2188/// The `padding` argument is a 4-tuple (pad_left, pad_right, pad_top, pad_bottom).
2189/// For left padding, it reflects the values from column 1 up to pad_left (in reverse order);
2190/// for right padding, it reflects from the second-to-last column backwards, and similarly for
2191/// vertical (height) padding.
2192pub struct ReflectionPad2d {
2193    padding: (usize, usize, usize, usize),
2194}
2195
2196impl ReflectionPad2d {
2197    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2198        Self { padding }
2199    }
2200}
2201
2202impl Module for ReflectionPad2d {
2203    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2204        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2205
2206        let (_n, _c, h, w) = xs.dims4()?;
2207
2208        // --- Horizontal Padding (along width, axis = 3) ---
2209        // For left padding, we reflect columns 1..=pad_left (in reverse order).
2210        let left_pad = if pad_left > 0 {
2211            // Create indices: [pad_left, pad_left-1, ..., 1]
2212            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2213            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2214        } else {
2215            None
2216        };
2217
2218        // For right padding, we reflect from the right side (excluding the last column).
2219        let right_pad = if pad_right > 0 {
2220            // Reflect from the second-to-last column: indices [w-2, w-3, ..., w-1-pad_right].
2221            let start = w as i64 - 2;
2222            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2223            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2224        } else {
2225            None
2226        };
2227
2228        // Concatenate horizontally (along width, dim=3)
2229        let x_padded_width = match (left_pad, right_pad) {
2230            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2231            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2232            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2233            (None, None) => xs.clone(),
2234        };
2235
2236        // --- Vertical Padding (along height, axis = 2) ---
2237        // For top padding, reflect rows 1..=pad_top (in reverse order)
2238        let top_pad = if pad_top > 0 {
2239            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2240            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2241        } else {
2242            None
2243        };
2244
2245        // For bottom padding, reflect from the bottom (excluding the last row)
2246        let bottom_pad = if pad_bottom > 0 {
2247            let start = h as i64 - 2;
2248            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2249            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2250        } else {
2251            None
2252        };
2253
2254        // Concatenate vertically (along height, dim=2)
2255        let x_padded = match (top_pad, bottom_pad) {
2256            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2257            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2258            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2259            (None, None) => x_padded_width,
2260        };
2261
2262        Ok(x_padded)
2263    }
2264}
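
// A minimal sketch of `ReflectionPad2d` on a 1x1x3x3 input with one pixel of
// padding on every side; the expected first row follows the reflection rule
// described above (and matches what `torch.nn.ReflectionPad2d((1, 1, 1, 1))`
// produces for this input).
#[cfg(test)]
mod reflection_pad2d_example {
    use super::ReflectionPad2d;
    use candle_core::{Device, IndexOp, Result, Tensor};
    use candle_nn::Module;

    #[test]
    fn pads_by_reflection() -> Result<()> {
        // Rows of the input: [0 1 2], [3 4 5], [6 7 8].
        let xs = Tensor::arange(0f32, 9f32, &Device::Cpu)?.reshape((1, 1, 3, 3))?;
        let pad = ReflectionPad2d::new((1, 1, 1, 1));
        let ys = pad.forward(&xs)?;
        assert_eq!(ys.dims(), &[1, 1, 5, 5]);
        // The top padded row reflects the second row of the width-padded input.
        assert_eq!(ys.i((0, 0, 0))?.to_vec1::<f32>()?, vec![4., 3., 4., 5., 4.]);
        Ok(())
    }
}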
2265
2266pub struct ScaledEmbedding {
2267    scale: f64,
2268    embedding: Embedding,
2269}
2270
2271impl ScaledEmbedding {
2272    pub fn new(scale: f64, embedding: Embedding) -> Self {
2273        Self { scale, embedding }
2274    }
2275
2276    pub fn embeddings(&self) -> &Tensor {
2277        self.embedding.embeddings()
2278    }
2279}
2280
2281impl Module for ScaledEmbedding {
2282    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2283        xs.apply(&self.embedding)? * self.scale
2284    }
2285}
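
// A minimal sketch of `ScaledEmbedding`, e.g. the common sqrt(hidden_size)
// scaling applied to token embeddings. The vocabulary size, hidden size and
// token ids below are illustrative assumptions.
#[cfg(test)]
mod scaled_embedding_example {
    use super::ScaledEmbedding;
    use candle_core::{DType, Device, IndexOp, Result, Tensor};
    use candle_nn::{Embedding, Module};

    #[test]
    fn scales_looked_up_embeddings() -> Result<()> {
        let dev = Device::Cpu;
        let hidden = 4usize;
        // A tiny 8-token vocabulary of all-ones embeddings.
        let weight = Tensor::ones((8, hidden), DType::F32, &dev)?;
        let emb = ScaledEmbedding::new((hidden as f64).sqrt(), Embedding::new(weight, hidden));
        let ids = Tensor::new(&[1u32, 3u32], &dev)?;
        let out = emb.forward(&ids)?;
        assert_eq!(out.dims(), &[2, hidden]);
        // Every value is 1.0 * sqrt(hidden) = 2.0.
        assert_eq!(out.i((0, 0))?.to_scalar::<f32>()?, 2.0);
        Ok(())
    }
}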