// mistralrs_core/layers.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};

use candle_core::{
    quantized::{QMatMul, QTensor},
    Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{
    BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
    LayerNorm, LayerNormConfig, Linear, Module,
};
use float8::F8E4M3;
use half::{bf16, f16};
use mistralrs_quant::{
    AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};

pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::repeat_kv;
use crate::{
    amoe::{AnyMoeTrainableLayer, MlpLayer},
    gguf::Content,
    models::llama,
    ops::SplitOp,
    vision_models::{
        gemma3::config::Gemma3TextConfig,
        llama4,
        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
        phi4::Phi4MMConfig,
    },
};

pub use mistralrs_quant::MatMul;

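/// Load a (possibly AFQ-quantized) token embedding table and wrap it in an [`Embedding`] layer.
///
/// A hypothetical usage sketch; `cfg`, `vb`, `quant_cfg`, and `input_ids` are assumed to come
/// from the surrounding model loader and are not defined in this file:
///
/// ```ignore
/// let embed_tokens = embedding(
///     cfg.vocab_size,
///     cfg.hidden_size,
///     vb.pp("model.embed_tokens"),
///     &quant_cfg,
/// )?;
/// // (bs, seq_len) token ids -> (bs, seq_len, hidden_size) activations
/// let xs = embed_tokens.forward(&input_ids)?;
/// ```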
pub fn embedding(
    in_size: usize,
    out_size: usize,
    vb: ShardedVarBuilder,
    config: &Option<QuantizedConfig>,
) -> Result<Embedding> {
    // AFQ quantization also applies to the embeddings, so dequantize the AFQ weights here.
    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
        let afq_layer =
            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
        afq_layer.dequantize_w()?
    } else {
        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
    };
    Ok(Embedding::new(embeddings, out_size))
}

pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get(size, "weight")?;
    if config.affine {
        let bias = vb.get(size, "bias")?;
        Ok(LayerNorm::new(weight, bias, config.eps))
    } else {
        Ok(LayerNorm::new_no_bias(weight, config.eps))
    }
}

pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<BatchNorm> {
    let config = config.into();
    if config.eps < 0. {
        candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    let running_mean = vb.get(num_features, "running_mean")?;
    let running_var = vb.get(num_features, "running_var")?;

    if config.affine {
        let weight = vb.get(num_features, "weight")?;
        let bias = vb.get(num_features, "bias")?;
        BatchNorm::new(
            num_features,
            running_mean,
            running_var,
            weight,
            bias,
            config.eps,
        )
    } else {
        BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
    }
}

pub fn group_norm(
    num_groups: usize,
    num_channels: usize,
    eps: f64,
    vb: ShardedVarBuilder,
) -> Result<GroupNorm> {
    let weight = vb.get(num_channels, "weight")?;
    let bias = vb.get(num_channels, "bias")?;
    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
}

pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}

pub fn conv2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}

pub fn conv1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

pub fn conv1d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv1d> {
    let ws = vb.get(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
    )?;
    Ok(Conv1d::new(ws, None, cfg))
}

pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    let bs = vb.get(out_dim, "bias")?;
    Ok(Linear::new(ws, Some(bs)))
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    Ok(Linear::new(ws, None))
}

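/// Convenience wrapper that picks [`linear`] or [`linear_no_bias`] from a boolean flag.
///
/// A hypothetical usage sketch; `cfg` and `vb` are assumed from the caller:
///
/// ```ignore
/// let q_proj = linear_b(cfg.hidden_size, cfg.num_heads * cfg.head_dim, cfg.attention_bias, vb.pp("q_proj"))?;
/// ```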
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: ShardedVarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}

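/// RMS normalization: `y = w * x / sqrt(mean(x^2) + eps)`, computed via
/// `candle_nn::ops::rms_norm`.
///
/// A hypothetical usage sketch; `cfg`, `vb`, and `xs` are assumed from the caller:
///
/// ```ignore
/// let input_layernorm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
/// let xs = input_layernorm.forward(&xs)?;
/// ```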
#[derive(Debug, Clone)]
pub struct RmsNorm {
    eps: f64,
    weight: Tensor,
}

impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        Ok(Self { eps, weight: w })
    }

    /// Gemma uses weight + 1.0
    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        let w = (w + 1.0)?;
        Ok(Self { eps, weight: w })
    }

    /// Gemma uses weight + 1.0. Undo for UQFF generation.
    pub fn undo_gemma(&self) -> Result<Self> {
        Ok(Self {
            eps: self.eps,
            weight: (&self.weight - 1.0)?,
        })
    }

    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
        Ok(Self { eps, weight: w })
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

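/// RMS normalization that always computes the statistics in `f32`: the input is upcast,
/// normalized by `1 / sqrt(mean(x^2) + eps)`, cast back to its original dtype, and then
/// scaled by `w`. Intended for cases where computing the statistics in the model dtype
/// (e.g. `bf16`) would lose precision.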
#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    w: Tensor,
    eps: f64,
}

impl F32RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }

    pub fn weight(&self) -> &Tensor {
        &self.w
    }
}

impl Module for F32RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}

#[derive(Debug, Clone)]
pub struct QRmsNorm {
    eps: f64,
    weight: Tensor,
}

impl QRmsNorm {
    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
        let scale = scale.dequantize(&scale.device())?;
        Ok(Self {
            eps: eps as f64,
            weight: scale,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

/// RoPE supporting LongRope
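///
/// Holds a (sin, cos) table built from the short factors and, when scaling is configured,
/// a second table built from the long factors; `forward` switches to the long table once the
/// maximum position id exceeds `original_max_position_embeddings`.
///
/// A hypothetical usage sketch; `cfg`, `device`, `q`, `k`, and the offset/position slices are
/// assumed from the attention layer:
///
/// ```ignore
/// let rope = PhiRotaryEmbedding::new(DType::F32, &cfg, &device)?;
/// // q, k: (bs, num_heads, seq_len, head_dim)
/// let (q, k) = rope.forward(&q, &k, &seqlen_offsets, &position_ids)?;
/// ```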
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    #[serde(alias = "yarn")]
    Yarn,
    #[serde(alias = "dynamic")]
    Dynamic,
    #[serde(alias = "linear")]
    Linear,
}

impl FromStr for ScaledRopeType {
    type Err = candle_core::Error;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "su" | "longrope" => Ok(Self::Su),
            "yarn" => Ok(Self::Yarn),
            "linear" => Ok(Self::Linear),
            "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Expected a scaled RoPE type of `su`/`longrope`, `yarn`, `linear`, or `dynamic`."
                    .to_string(),
            )),
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
    Classic {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    Scaled {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
    },
}

pub struct PhiRopeConfig {
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub original_max_position_embeddings: usize,
    pub rope_theta: f64,
    pub head_dim: usize,
    pub partial_rotary_factor: Option<f64>,
}

impl PhiRotaryEmbedding {
    fn new_classic_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        // Calculate scale
        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            match scaling_type {
                ScaledRopeType::Su => {
                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
                }
                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
            }
        };

        // Calculate inv freqs for short, long
        let inv_freq_long = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_short = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq_long.len();

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;

        // Calculate sin,cos for long
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
        let freqs_long = t.matmul(&inv_freq_long)?;
        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        // Calculate sin,cos for short
        let inv_freq_short =
            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
        let freqs_short = t.matmul(&inv_freq_short)?;
        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        Ok(Self {
            short_cos,
            short_sin,
            long_cos: Some(long_cos),
            long_sin: Some(long_sin),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        if !matches!(scaling_type, ScaledRopeType::Su) {
            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
        }

        if short_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
                short_factor.len(),
                dim / 2
            );
        }
        if long_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
                long_factor.len(),
                dim / 2
            );
        }

        // Short cos/sin
        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;

        // Long cos/sin
        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
        let cfg: PhiRopeConfig = cfg.into();

        match &cfg.rope_scaling {
            Some(PhiRopeScalingConfig::Classic {
                short_factor,
                long_factor,
                scaling_type,
            }) => {
                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
            }

            Some(PhiRopeScalingConfig::Scaled {
                short_factor,
                long_factor,
                scaling_type,
                long_mscale,
                short_mscale,
            }) => Self::new_scaled(
                short_factor,
                long_factor,
                scaling_type,
                *long_mscale,
                *short_mscale,
                &cfg,
                dtype,
                dev,
            ),

            None => Self::new_unscaled(&cfg, dtype, dev),
        }
    }

    /// Returns (sin, cos) taking into account LongRope
    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        let rot_dim = cos.dim(D::Minus1)? * 2;

        // Case for Phi 3 / Phi 4 mini
        if rot_dim != q.dim(D::Minus1)? {
            let rot_dim = cos.dim(D::Minus1)? * 2;
            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
                (q_embed, k_embed)
            } else {
                let mut q_embeds = Vec::new();
                let mut k_embeds = Vec::new();
                for (i, offset) in seqlen_offsets.iter().enumerate() {
                    let cos = cos.narrow(0, *offset, seq_len)?;
                    let sin = sin.narrow(0, *offset, seq_len)?;
                    let q_embed = candle_nn::rotary_emb::rope(
                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    let k_embed = candle_nn::rotary_emb::rope(
                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    q_embeds.push(q_embed);
                    k_embeds.push(k_embed);
                }
                let q_rot = Tensor::cat(&q_embeds, 0)?;
                let k_rot = Tensor::cat(&k_embeds, 0)?;
                (q_rot, k_rot)
            };

            Ok((
                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
            ))
        } else if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

/// RoPE for Llama3
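///
/// For `rope_type = "llama3"`, each default inverse frequency is rescaled based on its
/// wavelength: frequencies whose wavelength is below the high-frequency cutoff are kept as-is,
/// frequencies whose wavelength is above the low-frequency cutoff are divided by `factor`, and
/// the band in between is linearly interpolated between the two.
///
/// A hypothetical construction sketch; `cfg`, `device`, `q`, `k`, and `seqlen_offsets` are
/// assumed from the caller:
///
/// ```ignore
/// let rope = Llama3RotaryEmbedding::new_llama3(DType::BF16, &cfg, &device, true)?;
/// let (q, k) = rope.forward(&q, &k, &seqlen_offsets)?;
/// ```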
#[derive(Debug, Clone)]
pub struct Llama3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[serde(rename = "linear")]
    Linear,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: Option<f32>,
    pub high_freq_factor: Option<f32>,
    pub original_max_position_embeddings: Option<usize>,
    pub rope_type: Llama3RopeType,
}

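/// Default RoPE inverse frequencies for Llama-style models:
/// `inv_freq[k] = rope_theta^(-2k / head_dim)` with `head_dim = hidden_size / num_attention_heads`.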
fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
impl Llama3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &llama::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_llama4(
        dtype: DType,
        cfg: &llama4::TextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Llama3,
                factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position_embeddings,
            }) => {
                let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
                let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
                let original_max_position_embeddings = original_max_position_embeddings
                    .context("original_max_position_embeddings is required")?;

                let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

                let inv_freq = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / *factor
                        } else {
                            let smooth = (original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / *factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Linear,
                factor,
                ..
            }) => {
                let inv_freq_vec = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| freq / *factor)
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq_vec.len();
                let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_mllama3(
        dtype: DType,
        cfg: &MLlamaTextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Llama3,
                original_max_position_embeddings,
                factor,
                attention_factor: _,
                beta_fast: _,
                beta_slow: _,
                short_factor: _,
                long_factor: _,
                low_freq_factor,
                high_freq_factor,
            }) => {
                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
                let low_freq_factor = low_freq_factor
                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
                let high_freq_factor = high_freq_factor
                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;

                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;

                let inv_freq = (0..head_dim)
                    .step_by(2)
                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / factor
                        } else {
                            let smooth = (*original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(MLlamaRopeScaling {
                rope_type: other, ..
            }) => {
                candle_core::bail!(
                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
                )
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
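/// Multimodal rotary embedding (M-RoPE) used by Qwen2-VL: `position_ids` carries one row per
/// section (temporal, height, width) and `mrope_section` gives how many rotary dimensions each
/// section contributes to the final cos/sin tables.
///
/// A hypothetical usage sketch; `cfg`, `device`, and a `(3, bs, seq_len)` `position_ids` tensor
/// are assumed from the caller:
///
/// ```ignore
/// let rope = Qwen2VLRotaryEmbedding::new(cfg.rope_theta, head_dim, &device, cfg.mrope_section.clone())?;
/// let cos_sin = rope.compute_cos_sin(&position_ids, dtype)?;
/// rope.forward(&cos_sin, &mut q, &mut k)?;
/// ```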
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    /// (cos, sin)
    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct Qwen2_5VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2_5VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    /// (cos, sin)
    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

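/// Rotary embedding for the decoupled (RoPE-only) slice of DeepSeek-V2 attention heads, with
/// optional YaRN scaling via [`DeepSeekV2RopeScaling::Yarn`]. Note that `forward` uses the
/// interleaved `rope_i` layout.
///
/// A hypothetical construction sketch; the `cfg` fields are assumed from the model config:
///
/// ```ignore
/// let rope_cfg = DeepSeekV2RopeConfig {
///     rope_scaling: cfg.rope_scaling.clone(),
///     max_position_embeddings: cfg.max_position_embeddings,
///     rope_theta: cfg.rope_theta,
///     qk_rope_head_dim: cfg.qk_rope_head_dim,
/// };
/// let rope = DeepSeekV2RotaryEmbedding::new(&rope_cfg, dtype, &device)?;
/// let (q_pe, k_pe) = rope.forward(&q_pe, &k_pe, &seqlen_offsets)?;
/// ```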
#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}

pub struct DeepSeekV2RopeConfig {
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub max_position_embeddings: usize,
    pub rope_theta: f32,
    pub qk_rope_head_dim: usize,
}

impl DeepSeekV2RotaryEmbedding {
    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.qk_rope_head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

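    /// Inverse of the RoPE wavelength formula: the (fractional) dimension index whose rotations
    /// over `max_position_embeddings` equal `num_rot`, i.e.
    /// `dim * ln(max_position_embeddings / (num_rot * 2 * PI)) / (2 * ln(base))`.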
    fn yarn_find_correction_dim(
        num_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> f32 {
        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
            / (2. * base.ln())
    }

    fn yarn_find_correction_range(
        low_rot: f32,
        high_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> (f32, f32) {
        let low =
            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
        let high =
            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
        (low.max(0.), high.min(dim as f32 - 1.))
    }

    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
        if min == max {
            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
            max += 0.001;
        }
        let linear_func =
            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
        linear_func.clamp(0., 1)
    }

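    /// YaRN magnitude scaling: `1.0` when `scale <= 1`, otherwise `0.1 * mscale * ln(scale) + 1.0`.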
    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
        if scale <= 1. {
            return 1.;
        }
        0.1 * mscale * scale.ln() + 1.
    }

    #[allow(clippy::too_many_arguments)]
    fn new_yarn(
        cfg: &DeepSeekV2RopeConfig,
        dtype: DType,
        dev: &Device,
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        factor: f32,
        mscale: f32,
        mscale_all_dim: f32,
    ) -> Result<Self> {
        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
            .collect();
        let freq_extra_len = freq_extra.len();
        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
            .collect();
        let freq_inter_len = freq_inter.len();
        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;

        let (low, high) = Self::yarn_find_correction_range(
            beta_fast,
            beta_slow,
            cfg.qk_rope_head_dim,
            cfg.rope_theta,
            original_max_position_embeddings,
        );
        let inv_freq_mask =
            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
        let inv_freq = freq_inter
            .broadcast_mul(&(1. - &inv_freq_mask)?)?
            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;

        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((cfg.max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let mscale =
            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
                scaling_type: _,
                factor: _,
            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
            Some(DeepSeekV2RopeScaling::Yarn {
                original_max_position_embeddings,
                beta_fast,
                beta_slow,
                factor,
                mscale,
                mscale_all_dim,
                scaling_type: _,
            }) => Self::new_yarn(
                cfg,
                dtype,
                dev,
                *original_max_position_embeddings,
                *beta_fast,
                *beta_slow,
                *factor,
                *mscale,
                *mscale_all_dim,
            ),
            None => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope_i(
                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope_i(
                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

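/// LongRope rotary embedding for Phi-4-multimodal. Only the first
/// `head_dim * partial_rotary_factor` dimensions of each head are rotated; the rest pass
/// through unchanged and are concatenated back on after the rotation.
///
/// A hypothetical usage sketch; `cfg`, `device`, `q`, `k`, and the offset/position slices are
/// assumed from the attention layer:
///
/// ```ignore
/// let rope = Phi4MMRotaryEmbedding::new(dtype, &cfg, &device)?;
/// let (q, k) = rope.forward(&q, &k, &seqlen_offsets, &position_ids)?;
/// ```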
1341#[derive(Debug, Clone)]
1342pub struct Phi4MMRotaryEmbedding {
1343    short_sin: Tensor,
1344    short_cos: Tensor,
1345    long_cos: Option<Tensor>,
1346    long_sin: Option<Tensor>,
1347    original_max_position_embeddings: usize,
1348}
1349
1350#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1351#[serde(rename_all = "lowercase")]
1352pub enum Phi4MMScaledRopeType {
1353    #[serde(alias = "longrope")]
1354    LongRope,
1355    #[default]
1356    Default,
1357}
1358
1359#[derive(Debug, Clone, Deserialize, Serialize)]
1360pub struct Phi4MMRopeScalingConfig {
1361    short_factor: Option<Vec<f64>>,
1362    long_factor: Option<Vec<f64>>,
1363    #[serde(rename = "type")]
1364    scaling_type: Phi4MMScaledRopeType,
1365}
1366
1367impl Phi4MMRotaryEmbedding {
1368    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1369        let max_seq_len = cfg.max_position_embeddings;
1370        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1371
1372        let inv_freq: Vec<_> = (0..dim)
1373            .step_by(2)
1374            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1375            .collect();
1376        let inv_freq_len = inv_freq.len();
1377        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1378        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1379            .to_dtype(DType::F32)?
1380            .reshape((max_seq_len, 1))?;
1381        let freqs = t.matmul(&inv_freq)?;
1382        let sin = freqs.sin()?.to_dtype(dtype)?;
1383        let cos = freqs.cos()?.to_dtype(dtype)?;
1384        Ok(Self {
1385            short_cos: cos,
1386            short_sin: sin,
1387            long_cos: None,
1388            long_sin: None,
1389            original_max_position_embeddings: cfg.original_max_position_embeddings,
1390        })
1391    }
1392
1393    #[allow(clippy::too_many_arguments)]
1394    fn new_longrope(
1395        short_factor: &[f64],
1396        long_factor: &[f64],
1397        cfg: &Phi4MMConfig,
1398        dtype: DType,
1399        dev: &Device,
1400    ) -> Result<Self> {
1401        let max_seq_len = cfg.max_position_embeddings;
1402        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1403
1404        // Calculate scale
1405        let scale =
1406            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1407        let scaling_factor = if scale <= 1.0 {
1408            1.0
1409        } else {
1410            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1411        };
1412
1413        // Short cos/sin
1414        let inv_freq_short: Vec<_> = (0..dim)
1415            .step_by(2)
1416            .enumerate()
1417            .map(|(k, i)| {
1418                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1419            })
1420            .collect();
1421        let inv_freq_len_short = inv_freq_short.len();
1422        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1423        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1424            .to_dtype(DType::F32)?
1425            .reshape((max_seq_len, 1))?;
1426        let freqs_short = t_short.matmul(&inv_freq_short)?;
1427        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1428        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1429
1430        // Long cos/sin
1431        let inv_freq_long: Vec<_> = (0..dim)
1432            .step_by(2)
1433            .enumerate()
1434            .map(|(k, i)| {
1435                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1436            })
1437            .collect();
1438        let inv_freq_len_long = inv_freq_long.len();
1439        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1440        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1441            .to_dtype(DType::F32)?
1442            .reshape((max_seq_len, 1))?;
1443        let freqs_long = t_long.matmul(&inv_freq_long)?;
1444        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1445        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1446
1447        Ok(Self {
1448            short_cos: cos_short,
1449            short_sin: sin_short,
1450            long_cos: Some(cos_long),
1451            long_sin: Some(sin_long),
1452            original_max_position_embeddings: cfg.original_max_position_embeddings,
1453        })
1454    }
1455
1456    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1457        match &cfg.rope_scaling {
1458            Some(Phi4MMRopeScalingConfig {
1459                scaling_type: Phi4MMScaledRopeType::LongRope,
1460                short_factor: Some(short_factor),
1461                long_factor: Some(long_factor),
1462            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1463
1464            _ => Self::new_unscaled(cfg, dtype, dev),
1465        }
1466    }
1467
1468    /// Returns (sin, cos) taking into account LongRope
1469    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1470        if self.long_cos.is_none() {
1471            return (&self.short_sin, &self.short_cos);
1472        }
1473        let seq_len = position_ids.iter().max().unwrap() + 1;
1474        if seq_len > self.original_max_position_embeddings {
1475            (
1476                self.long_sin.as_ref().unwrap(),
1477                self.long_cos.as_ref().unwrap(),
1478            )
1479        } else {
1480            (&self.short_sin, &self.short_cos)
1481        }
1482    }
1483
1484    pub fn forward(
1485        &self,
1486        q: &Tensor,
1487        k: &Tensor,
1488        seqlen_offsets: &[usize],
1489        position_ids: &[usize],
1490    ) -> Result<(Tensor, Tensor)> {
1491        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1492        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1493
1494        let rot_dim = cos.dim(D::Minus1)? * 2;
1495        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1496        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1497        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1498        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1499
1500        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1501            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1502            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1503            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1504            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1505            (q_embed, k_embed)
1506        } else {
1507            let mut q_embeds = Vec::new();
1508            let mut k_embeds = Vec::new();
1509            for (i, offset) in seqlen_offsets.iter().enumerate() {
1510                let cos = cos.narrow(0, *offset, seq_len)?;
1511                let sin = sin.narrow(0, *offset, seq_len)?;
1512                let q_embed = candle_nn::rotary_emb::rope(
1513                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1514                    &cos,
1515                    &sin,
1516                )?;
1517                let k_embed = candle_nn::rotary_emb::rope(
1518                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1519                    &cos,
1520                    &sin,
1521                )?;
1522                q_embeds.push(q_embed);
1523                k_embeds.push(k_embed);
1524            }
1525            let q_rot = Tensor::cat(&q_embeds, 0)?;
1526            let k_rot = Tensor::cat(&k_embeds, 0)?;
1527            (q_rot, k_rot)
1528        };
1529
1530        Ok((
1531            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1532            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1533        ))
1534    }
1535}
1536
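/// Rotary embedding for Gemma 3: a standard [`RotaryEmbedding`] whose inverse
/// frequencies are optionally divided by a linear `rope_scaling` factor.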
1537#[derive(Debug, Clone)]
1538pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1539
1540#[derive(Debug, Clone, Deserialize, Serialize)]
1541#[serde(rename_all = "lowercase")]
1542pub enum Gemma3ScaledRopeType {
1543    #[serde(alias = "linear")]
1544    Linear,
1545}
1546
1547#[derive(Debug, Clone, Deserialize, Serialize)]
1548pub struct Gemma3RopeScalingConfig {
1549    factor: f64,
1550    rope_type: Gemma3ScaledRopeType,
1551}
1552
1553impl Gemma3RotaryEmbedding {
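    /// Linear RoPE scaling: dividing the inverse frequencies by `factor` is
    /// equivalent to dividing the position indices by `factor`, stretching the
    /// usable context by that factor.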
1554    fn new_linear(
1555        cfg: &Gemma3TextConfig,
1556        factor: f64,
1557        is_gpt_neox: bool,
1558        dtype: DType,
1559        dev: &Device,
1560    ) -> Result<Self> {
1561        let max_seq_len = cfg.max_position_embeddings;
1562        let dim = cfg.head_dim;
1563
1564        let inv_freq: Vec<_> = (0..dim)
1565            .step_by(2)
1566            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1567            .collect();
1568        let inv_freq_len = inv_freq.len();
1569        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1570        let inv_freq = (inv_freq / factor)?;
1571
1572        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1573            .to_dtype(DType::F32)?
1574            .reshape((max_seq_len, 1))?;
1575        let freqs = t.matmul(&inv_freq)?;
1576        let sin = freqs.sin()?.to_dtype(dtype)?;
1577        let cos = freqs.cos()?.to_dtype(dtype)?;
1578        Ok(Self(RotaryEmbedding {
1579            cos,
1580            sin,
1581            is_gpt_neox,
1582        }))
1583    }
1584
1585    pub fn new(
1586        is_gpt_neox: bool,
1587        dtype: DType,
1588        cfg: &Gemma3TextConfig,
1589        dev: &Device,
1590    ) -> Result<Self> {
1591        match &cfg.rope_scaling {
1592            Some(Gemma3RopeScalingConfig {
1593                rope_type: Gemma3ScaledRopeType::Linear,
1594                factor,
1595            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1596
1597            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1598        }
1599    }
1600
1601    pub fn forward(
1602        &self,
1603        q: &Tensor,
1604        k: &Tensor,
1605        seqlen_offsets: &[usize],
1606    ) -> Result<(Tensor, Tensor)> {
1607        self.0.forward(q, k, seqlen_offsets)
1608    }
1609}
1610
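/// Rotary embedding used by the Dia model: half of `head_dim` is assigned
/// geometrically spaced timescales between `min_timescale` and `max_timescale`,
/// and positions are divided by those timescales to obtain the rotation angles.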
1611pub struct DiaRotaryEmbedding {
1612    timescale: Tensor,
1613    dtype: DType,
1614}
1615
1616impl DiaRotaryEmbedding {
1617    pub fn new(
1618        min_timescale: f32,
1619        max_timescale: f32,
1620        head_dim: usize,
1621        device: &Device,
1622        dtype: DType,
1623    ) -> Result<Self> {
1624        assert_eq!(head_dim % 2, 0);
1625        let half_embedding_dim = head_dim / 2;
1626
1627        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
1628        let timescale = fraction
1629            .into_iter()
1630            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
1631            .collect::<Vec<_>>();
1632
1633        let timescale_len = timescale.len();
1634        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;
1635
1636        Ok(Self { timescale, dtype })
1637    }
1638
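    /// Rotate `xs` by the per-position angles: with `xs` split into halves
    /// `(x1, x2)` along the last dim, returns `(x1*cos - x2*sin, x2*cos + x1*sin)`.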
1639    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
1640        let freqs = positions
1641            .unsqueeze(D::Minus1)?
1642            .unsqueeze(D::Minus1)?
1643            .broadcast_div(&self.timescale)?;
1644
1645        let sin = freqs.sin()?.to_dtype(self.dtype)?;
1646        let cos = freqs.cos()?.to_dtype(self.dtype)?;
1647
1648        let split = xs.chunk(2, D::Minus1)?;
1649        let first_half = &split[0];
1650        let second_half = &split[1];
1651
1652        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
1653        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;
1654
1655        Tensor::cat(&[first_part, second_part], D::Minus1)
1656    }
1657}
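
/// Linear layer backed by a [`QMatMul`] with an optional bias. When the weight
/// is quantized, activations are cast to f32 for the matmul and the result is
/// cast back to the layer's stored dtype.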
1658#[derive(Debug, Clone)]
1659pub struct QLinear {
1660    inner: QMatMul,
1661    bias: Option<Tensor>,
1662    dtype: DType,
1663}
1664
1665impl QLinear {
1666    pub fn new<R: std::io::Read + std::io::Seek>(
1667        ct: &mut Content<'_, R>,
1668        name: &str,
1669        device: &Device,
1670    ) -> Result<Self> {
1671        let w = ct.tensor(&format!("{name}.weight"), device)?;
1672        let b = ct.tensor(&format!("{name}.bias"), device)?;
1673        let inner = QMatMul::from_qtensor(w)?;
1674        let bias = b.dequantize(device)?;
1675        Ok(Self {
1676            inner,
1677            bias: Some(bias),
1678            dtype: DType::F32,
1679        })
1680    }
1681
1682    pub fn from_linear(linear: Linear) -> Self {
1683        Self {
1684            inner: QMatMul::Tensor(linear.weight().clone()),
1685            bias: linear.bias().cloned(),
1686            dtype: linear.weight().dtype(),
1687        }
1688    }
1689
1690    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1691        let dtype = w.dtype();
1692        Self {
1693            inner: QMatMul::Tensor(w),
1694            bias: b,
1695            dtype,
1696        }
1697    }
1698
1699    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1700        if let Some(ref b) = b {
1701            assert_eq!(b.dtype(), DType::F32);
1702        }
1703        Self {
1704            inner: QMatMul::QTensor(Arc::new(w)),
1705            bias: b,
1706            dtype: DType::F32,
1707        }
1708    }
1709
1710    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1711        Self {
1712            inner,
1713            bias: old.bias.clone(),
1714            dtype: old.dtype,
1715        }
1716    }
1717
1718    pub fn inner(&mut self) -> &mut QMatMul {
1719        &mut self.inner
1720    }
1721
1722    pub fn inner_ref(&self) -> &QMatMul {
1723        &self.inner
1724    }
1725
1726    pub fn is_quant(&self) -> bool {
1727        matches!(self.inner, QMatMul::QTensor(_))
1728    }
1729
1730    pub fn bias(&self) -> Option<&Tensor> {
1731        self.bias.as_ref()
1732    }
1733
1734    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1735        self.bias.as_mut()
1736    }
1737}
1738
1739impl Module for QLinear {
1740    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1741        let xs = if self.is_quant() {
1742            xs.to_dtype(DType::F32)?
1743        } else {
1744            xs.clone()
1745        };
1746        if let Some(bias) = &self.bias {
1747            self.inner
1748                .forward(&xs)?
1749                .broadcast_add(bias)?
1750                .to_dtype(self.dtype)
1751        } else {
1752            self.inner.forward(&xs)?.to_dtype(self.dtype)
1753        }
1754    }
1755}
1756
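/// Standard rotary position embedding with precomputed cos/sin tables.
/// `is_gpt_neox` selects the half-split ("NeoX") rotation (`rotary_emb::rope`)
/// instead of the interleaved variant (`rotary_emb::rope_i`).
///
/// A minimal usage sketch (shapes and values are illustrative):
/// ```ignore
/// let rope = RotaryEmbedding::new(10_000.0, 64, 4096, &Device::Cpu, true, DType::F32)?;
/// // q, k: (batch, n_heads, seq_len, head_dim)
/// let (q, k) = rope.forward(&q, &k, &[0])?;
/// ```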
1757#[derive(Debug, Clone)]
1758pub struct RotaryEmbedding {
1759    cos: Tensor,
1760    sin: Tensor,
1761    is_gpt_neox: bool,
1762}
1763
1764impl RotaryEmbedding {
1765    pub fn new(
1766        base: f32,
1767        head_dim: usize,
1768        max_position_embeddings: usize,
1769        device: &Device,
1770        is_gpt_neox: bool,
1771        dtype: DType,
1772    ) -> Result<Self> {
1773        let inv_freq: Vec<_> = (0..head_dim)
1774            .step_by(2)
1775            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1776            .collect();
1777        let inv_freq_len = inv_freq.len();
1778        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1779        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1780            .to_dtype(DType::F32)?
1781            .reshape((max_position_embeddings, 1))?;
1782        let freqs = t.matmul(&inv_freq)?;
1783        let sin = freqs.sin()?.to_dtype(dtype)?;
1784        let cos = freqs.cos()?.to_dtype(dtype)?;
1785
1786        Ok(Self {
1787            cos,
1788            sin,
1789            is_gpt_neox,
1790        })
1791    }
1792
1793    pub fn new_partial(
1794        base: f32,
1795        rot_dim: usize,
1796        max_position_embeddings: usize,
1797        device: &Device,
1798        is_gpt_neox: bool,
1799        dtype: DType,
1800    ) -> Result<Self> {
1801        let inv_freq: Vec<_> = (0..rot_dim)
1802            .step_by(2)
1803            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
1804            .collect();
1805        let inv_freq_len = inv_freq.len();
1806        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1807        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1808            .to_dtype(DType::F32)?
1809            .reshape((max_position_embeddings, 1))?;
1810        let freqs = t.matmul(&inv_freq)?;
1811        let sin = freqs.sin()?.to_dtype(dtype)?;
1812        let cos = freqs.cos()?.to_dtype(dtype)?;
1813
1814        Ok(Self {
1815            cos,
1816            sin,
1817            is_gpt_neox,
1818        })
1819    }
1820
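    /// Apply the rotary embedding to `q` and `k` (both `(b, h, seq, head_dim)`).
    /// With the `cuda` feature and matching q/k head counts, the fused in-place
    /// kernel from `mistralrs_quant` is used; otherwise this falls back to
    /// candle's rope applied per sequence offset.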
1821    pub fn forward(
1822        &self,
1823        q: &Tensor,
1824        k: &Tensor,
1825        seqlen_offsets: &[usize],
1826    ) -> Result<(Tensor, Tensor)> {
1827        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
1828        let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;
1829
1830        let rope = if self.is_gpt_neox {
1831            candle_nn::rotary_emb::rope
1832        } else {
1833            candle_nn::rotary_emb::rope_i
1834        };
1835
1836        if cfg!(feature = "cuda") && qh == kh {
1837            let (cos, sin) = if seqlen_offsets.len() == 1 {
1838                (
1839                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
1840                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
1841                )
1842            } else {
1843                let mut cos_s = Vec::new();
1844                let mut sin_s = Vec::new();
1845                for offset in seqlen_offsets {
1846                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
1847                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
1848                }
1849                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
1850            };
1851
1852            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
1853            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
1854            mistralrs_quant::rotary::apply_rotary_inplace(
1855                &q_embed,
1856                &k_embed,
1857                &cos,
1858                &sin,
1859                self.is_gpt_neox,
1860            )?;
1861            let mut q = q_embed
1862                .reshape((b_sz, seq_len, qh, n_embd))?
1863                .transpose(1, 2)?;
1864            let mut k = k_embed
1865                .reshape((b_sz, seq_len, kh, n_embd))?
1866                .transpose(1, 2)?;
1867            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
1868                q = q.contiguous()?;
1869                k = k.contiguous()?;
1870            }
1871            Ok((q, k))
1872        } else if seqlen_offsets.len() == 1 {
1873            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1874            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1875            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1876            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
1877            Ok((q_embed, k_embed))
1878        } else {
1879            let mut q_embeds = Vec::new();
1880            let mut k_embeds = Vec::new();
1881            for (i, offset) in seqlen_offsets.iter().enumerate() {
1882                let cos = self.cos.narrow(0, *offset, seq_len)?;
1883                let sin = self.sin.narrow(0, *offset, seq_len)?;
1884                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1885                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1886                q_embeds.push(q_embed);
1887                k_embeds.push(k_embed);
1888            }
1889            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1890        }
1891    }
1892}
1893
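/// Activation functions referenced by model configs, deserialized from the
/// HuggingFace `hidden_act` naming (e.g. `gelu_new`, `gelu_pytorch_tanh`).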
1894#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
1895#[serde(rename_all = "lowercase")]
1896pub enum Activation {
1897    #[default]
1898    #[serde(alias = "gelu")]
1899    Gelu,
1900    #[serde(alias = "gelu_new")]
1901    NewGelu,
1902    Relu,
1903    Relu2,
1904    Relu6,
1905    Silu,
1906    Sigmoid,
1907    HardSigmoid,
1908    Swiglu,
1909    Swish,
1910    HardSwish,
1911    Elu(f64),
1912    LeakyRelu(f64),
1913    #[serde(alias = "gelu_pytorch_tanh")]
1914    GeluPytorchTanh,
1915    QuickGelu,
1916}
1917
1918impl Module for Activation {
1919    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1920        match self {
1921            Self::Gelu => xs.gelu_erf(),
1922            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
1923            Self::NewGelu => xs.gelu(),
1924            Self::Relu => xs.relu(),
1925            Self::Relu2 => xs.relu()?.sqr(),
1926            Self::Relu6 => xs.clamp(0f32, 6f32),
1927            Self::Silu => xs.silu(),
1928            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
1929            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
1930            Self::Swiglu => candle_nn::ops::swiglu(xs),
1931            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
1932            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
1933            &Self::Elu(alpha) => xs.elu(alpha),
1934            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
1935            Self::GeluPytorchTanh => xs.gelu(),
1936            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
1937        }
1938    }
1939}
1940
1941impl TryInto<candle_nn::Activation> for Activation {
1942    type Error = candle_core::Error;
1943
1944    fn try_into(self) -> Result<candle_nn::Activation> {
1945        match self {
1946            Self::Gelu => Ok(candle_nn::Activation::Gelu),
1947            Self::Relu => Ok(candle_nn::Activation::Relu),
1948            Self::Silu => Ok(candle_nn::Activation::Silu),
1949            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
1950            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
1951            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
1952            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
1953            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
1954            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
1955            Self::Swish => Ok(candle_nn::Activation::Swish),
1956            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
1957            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
1958            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
1959            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
1960            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
1961        }
1962    }
1963}
1964
1965#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1966pub struct Conv3dConfig {
1967    pub padding: usize,
1968    pub stride: usize,
1969    pub dilation: usize,
1970    pub groups: usize,
1971}
1972
1973impl Default for Conv3dConfig {
1974    fn default() -> Self {
1975        Self {
1976            padding: 0,
1977            stride: 1,
1978            dilation: 1,
1979            groups: 1,
1980        }
1981    }
1982}
1983
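/// Bias-free 3D convolution emulated with two 2D convolutions, one per temporal
/// slice of the kernel. This assumes a temporal kernel size (and input temporal
/// extent) of 2; the two slice outputs are summed and the temporal dim restored
/// with `unsqueeze(2)`.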
1984pub struct Conv3dNoBias {
1985    conv2d_1: Conv2d,
1986    conv2d_2: Conv2d,
1987}
1988
1989impl Conv3dNoBias {
1990    pub fn new(
1991        in_channels: usize,
1992        out_channels: usize,
1993        kernel_sizes: [usize; 3],
1994        cfg: Conv3dConfig,
1995        vb: ShardedVarBuilder,
1996    ) -> Result<Self> {
1997        let ws = vb.get(
1998            (
1999                out_channels,
2000                in_channels / cfg.groups,
2001                kernel_sizes[0],
2002                kernel_sizes[1],
2003                kernel_sizes[2],
2004            ),
2005            "weight",
2006        )?;
2007
2008        // Split on temporal dimension
2009        // https://github.com/pytorch/pytorch/issues/139066
2010
2011        let w1 = ws.i((.., .., 0, .., ..))?;
2012        let w2 = ws.i((.., .., 1, .., ..))?;
2013
2014        let cfg = Conv2dConfig {
2015            padding: cfg.padding,
2016            stride: cfg.stride,
2017            dilation: cfg.dilation,
2018            groups: cfg.groups,
2019        };
2020
2021        Ok(Self {
2022            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
2023            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
2024        })
2025    }
2026}
2027
2028impl Module for Conv3dNoBias {
2029    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2030        let xs1 = xs.i((.., .., 0, .., ..))?;
2031        let xs2 = xs.i((.., .., 1, .., ..))?;
2032
2033        (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
2034    }
2035}
2036
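/// Small helpers used by [`clamp_for_f16`]: `is_inf` builds an elementwise
/// +inf mask and `any` reports whether any element of a tensor is non-zero.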
2037pub trait TensorInfExtend {
2038    fn is_inf(&self) -> Result<Self>
2039    where
2040        Self: Sized;
2041    fn any(&self) -> Result<bool>;
2042}
2043
2044impl TensorInfExtend for Tensor {
2045    fn is_inf(&self) -> Result<Self> {
2046        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
2047    }
2048
2049    fn any(&self) -> Result<bool> {
2050        let sum = self.sum_all()?;
2051        match self.dtype() {
2052            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
2053            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
2054            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
2055            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
2056            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
2057            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
2058            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
2059            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
2060            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
2061            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
2062        }
2063    }
2064}
2065
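/// Clamp `xs` to slightly below the maximum of its dtype (with an extra margin
/// when infinities are present) so later casts to f16 do not overflow.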
2066pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
2067    let mut max = match xs.dtype() {
2068        DType::U8 => u8::MAX as f32 - 1000.,
2069        DType::U32 => u32::MAX as f32 - 1000.,
2070        DType::I16 => i16::MAX as f32 - 1000.,
2071        DType::I32 => i32::MAX as f32 - 1000.,
2072        DType::I64 => i64::MAX as f32 - 1000.,
2073        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
2074        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
2075        DType::F32 => f32::MAX - 1000.,
2076        DType::F64 => f64::MAX as f32 - 1000.,
2077        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
2078    };
2079    if xs.is_inf()?.any()? {
2080        max -= 1000.;
2081    }
2082    xs.clamp(-max, max)
2083}
2084
2085pub struct FloatInfo {
2086    /// Minimum representable value.
2087    pub min: f64,
2088    /// Maximum representable value.
2089    pub max: f64,
2090    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
2091    pub eps: f64,
2092    pub dtype: DType,
2093}
2094
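/// Float-type metadata in the spirit of numpy's `finfo`.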
2095pub trait GetFloatInfo {
2096    fn finfo(&self) -> Result<FloatInfo>;
2097}
2098
2099impl GetFloatInfo for DType {
2100    fn finfo(&self) -> Result<FloatInfo> {
2101        let finfo = match self {
2102            Self::BF16 => FloatInfo {
2103                min: bf16::MIN.to_f64(),
2104                max: bf16::MAX.to_f64(),
2105                eps: bf16::EPSILON.to_f64(),
2106                dtype: DType::BF16,
2107            },
2108            Self::F16 => FloatInfo {
2109                min: f16::MIN.to_f64(),
2110                max: f16::MAX.to_f64(),
2111                eps: f16::EPSILON.to_f64(),
2112                dtype: DType::F16,
2113            },
2114            Self::F32 => FloatInfo {
2115                min: f32::MIN as f64,
2116                max: f32::MAX as f64,
2117                eps: f32::EPSILON as f64,
2118                dtype: DType::F32,
2119            },
2120            Self::F64 => FloatInfo {
2121                min: f64::MIN,
2122                max: f64::MAX,
2123                eps: f64::EPSILON,
2124                dtype: DType::F64,
2125            },
2126            Self::F8E4M3 => FloatInfo {
2127                min: F8E4M3::MIN.to_f64(),
2128                max: F8E4M3::MAX.to_f64(),
2129                eps: F8E4M3::EPSILON.to_f64(),
2130                dtype: DType::F8E4M3,
2131            },
2132            other => {
2133                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
2134            }
2135        };
2136        Ok(finfo)
2137    }
2138}
2139
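/// Gated MLP with tensor-parallel projections: column-parallel `gate`/`up` and a
/// row-parallel `down`, computing `down(act(gate(x)) * up(x))`.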
2140#[derive(Clone)]
2141pub struct Mlp {
2142    pub gate: Arc<dyn QuantMethod>,
2143    pub up: Arc<dyn QuantMethod>,
2144    pub down: Arc<dyn QuantMethod>,
2145    act: Activation,
2146    params: Vec<usize>,
2147}
2148
2149impl Mlp {
2150    pub fn new(
2151        vb: ShardedVarBuilder,
2152        hidden_size: usize,
2153        intermediate_size: usize,
2154        quantization_config: &Option<QuantizedConfig>,
2155        hidden_act: Activation,
2156        comm: &Arc<mistralrs_quant::Comm>,
2157    ) -> Result<Self> {
2158        Ok(Self {
2159            gate: ColumnParallelLayer::new(
2160                hidden_size,
2161                intermediate_size,
2162                quantization_config,
2163                false,
2164                comm,
2165                vb.pp("gate_proj"),
2166            )?,
2167            up: ColumnParallelLayer::new(
2168                hidden_size,
2169                intermediate_size,
2170                quantization_config,
2171                false,
2172                comm,
2173                vb.pp("up_proj"),
2174            )?,
2175            down: RowParallelLayer::new(
2176                intermediate_size,
2177                hidden_size,
2178                quantization_config,
2179                false,
2180                comm,
2181                vb.pp("down_proj"),
2182            )?,
2183            act: hidden_act,
2184            params: vec![hidden_size, intermediate_size],
2185        })
2186    }
2187
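    /// Build the MLP from a fused `gate_up_proj` weight, splitting it into the
    /// gate and up projections. Only `chunks == 2` is supported.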
2188    pub fn new_merged(
2189        vb: ShardedVarBuilder,
2190        hidden_size: usize,
2191        intermediate_size: usize,
2192        chunks: usize,
2193        quantization_config: &Option<QuantizedConfig>,
2194        hidden_act: Activation,
2195        comm: &Arc<mistralrs_quant::Comm>,
2196    ) -> Result<Self> {
2197        assert!(chunks == 2, "Only gate_up_proj merge is supported!");
2198        let gate_up_projs = ColumnParallelLayer::new_merged(
2199            hidden_size,
2200            intermediate_size * 2,
2201            2,
2202            quantization_config,
2203            false,
2204            comm,
2205            vb.pp("gate_up_proj"),
2206        )?;
2207
2208        Ok(Self {
2209            gate: gate_up_projs[0].to_owned(),
2210            up: gate_up_projs[1].to_owned(),
2211            down: RowParallelLayer::new(
2212                intermediate_size,
2213                hidden_size,
2214                quantization_config,
2215                false,
2216                comm,
2217                vb.pp("down_proj"),
2218            )?,
2219            act: hidden_act,
2220            params: vec![hidden_size, intermediate_size],
2221        })
2222    }
2223
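    /// Rebuild an MLP with the same `(hidden_size, intermediate_size)` shapes and
    /// activation, without a quantization config.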
2224    pub fn replicate(
2225        params: &[usize],
2226        vb: ShardedVarBuilder,
2227        act: Activation,
2228        comm: &Arc<mistralrs_quant::Comm>,
2229    ) -> Result<Self> {
2230        Self::new(vb, params[0], params[1], &None, act, comm)
2231    }
2232
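    /// Forward pass: casts to the quantized activation dtype if required,
    /// applies the activation to the gate projection, multiplies by the up
    /// projection, applies the down projection, and casts back.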
2233    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2234        let original_dtype = xs.dtype();
2235        let mut xs = xs.clone();
2236        if let Some(t) = self.gate.quantized_act_type() {
2237            xs = xs.to_dtype(t)?;
2238        }
2239        let lhs = self.gate.forward(&xs)?;
2240        let rhs = self.up.forward(&xs)?;
2241        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
2242            &lhs,
2243            &rhs,
2244            self.act.try_into()?,
2245        )?)?;
2246        if self.gate.quantized_act_type().is_some() {
2247            res = res.to_dtype(original_dtype)?;
2248        }
2249        Ok(res)
2250    }
2251}
2252
2253impl AnyMoeTrainableLayer for Mlp {}
2254
2255impl MlpLayer for Mlp {
2256    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2257        let original_dtype = xs.dtype();
2258        let mut xs = xs.clone();
2259        if let Some(t) = self.gate.quantized_act_type() {
2260            xs = xs.to_dtype(t)?;
2261        }
2262        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2263        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
2264        let mut res = if matches!(
2265            self.act,
2266            Activation::Gelu | Activation::Silu | Activation::Relu
2267        ) {
2268            MatMul.qmethod_matmul(
2269                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
2270                &*self.down,
2271            )?
2272        } else {
2273            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2274        };
2275        if self.gate.quantized_act_type().is_some() {
2276            res = res.to_dtype(original_dtype)?;
2277        }
2278        Ok(res)
2279    }
2280    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2281        vec![&mut self.gate, &mut self.up, &mut self.down]
2282    }
2283    fn clone(&self) -> Box<dyn MlpLayer> {
2284        Box::new(Clone::clone(self))
2285    }
2286    fn get_params(&self) -> &[usize] {
2287        &self.params
2288    }
2289    fn hidden_act(&self) -> Activation {
2290        self.act
2291    }
2292    // gate, up, down
2293    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2294        let gate = if let Some(ref delta) = deltas[0] {
2295            self.gate.add_delta_w(delta)?
2296        } else {
2297            self.gate.clone()
2298        };
2299        let up = if let Some(ref delta) = deltas[1] {
2300            self.up.add_delta_w(delta)?
2301        } else {
2302            self.up.clone()
2303        };
2304        let down = if let Some(ref delta) = deltas[2] {
2305            self.down.add_delta_w(delta)?
2306        } else {
2307            self.down.clone()
2308        };
2309
2310        Ok(Box::new(Self {
2311            gate,
2312            up,
2313            down,
2314            act: self.act,
2315            params: self.params.clone(),
2316        }))
2317    }
2318
2319    fn dtype_device(&self) -> (DType, Device) {
2320        self.gate.dtype_and_device()
2321    }
2322}
2323
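/// 2D average pooling with a square kernel and stride, wrapping
/// `avg_pool2d_with_stride`.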
2324pub struct AvgPool2d {
2325    kernel_size: usize,
2326    stride: usize,
2327}
2328
2329impl AvgPool2d {
2330    pub fn new(kernel_size: usize, stride: usize) -> Self {
2331        Self {
2332            kernel_size,
2333            stride,
2334        }
2335    }
2336
2337    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2338        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2339    }
2340}
2341
2342/// Applies 2D reflection padding to a tensor of shape (N, C, H, W).
2343///
2344/// The `padding` argument is a 4-tuple (pad_left, pad_right, pad_top, pad_bottom).
2345/// For left padding, it reflects the values from column 1 up to pad_left (in reverse order);
2346/// for right padding, it reflects from the second-to-last column backwards, and similarly for
2347/// vertical (height) padding.
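///
/// For example, a width-4 row `[a, b, c, d]` padded with
/// `(pad_left, pad_right, pad_top, pad_bottom) = (2, 1, 0, 0)` becomes
/// `[c, b, a, b, c, d, c]`.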
2348pub struct ReflectionPad2d {
2349    padding: (usize, usize, usize, usize),
2350}
2351
2352impl ReflectionPad2d {
2353    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2354        Self { padding }
2355    }
2356}
2357
2358impl Module for ReflectionPad2d {
2359    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2360        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2361
2362        let (_n, _c, h, w) = xs.dims4()?;
2363
2364        // --- Horizontal Padding (along width, axis = 3) ---
2365        // For left padding, we reflect columns 1..=pad_left (in reverse order).
2366        let left_pad = if pad_left > 0 {
2367            // Create indices: [pad_left, pad_left-1, ..., 1]
2368            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2369            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2370        } else {
2371            None
2372        };
2373
2374        // For right padding, we reflect from the right side (excluding the last column).
2375        let right_pad = if pad_right > 0 {
2376            // Generate indices [w-2, w-3, ..., w-1-pad_right] (reflect from the second-to-last column).
2377            let start = w as i64 - 2;
2378            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2379            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2380        } else {
2381            None
2382        };
2383
2384        // Concatenate horizontally (along width, dim=3)
2385        let x_padded_width = match (left_pad, right_pad) {
2386            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2387            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2388            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2389            (None, None) => xs.clone(),
2390        };
2391
2392        // --- Vertical Padding (along height, axis = 2) ---
2393        // For top padding, reflect rows 1..=pad_top (in reverse order)
2394        let top_pad = if pad_top > 0 {
2395            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2396            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2397        } else {
2398            None
2399        };
2400
2401        // For bottom padding, reflect from the bottom (excluding the last row)
2402        let bottom_pad = if pad_bottom > 0 {
2403            let start = h as i64 - 2;
2404            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2405            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2406        } else {
2407            None
2408        };
2409
2410        // Concatenate vertically (along height, dim=2)
2411        let x_padded = match (top_pad, bottom_pad) {
2412            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2413            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2414            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2415            (None, None) => x_padded_width,
2416        };
2417
2418        Ok(x_padded)
2419    }
2420}
2421
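/// Embedding whose outputs are multiplied by a constant `scale`, e.g. the
/// `sqrt(hidden_size)` scaling some models apply to token embeddings.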
2422pub struct ScaledEmbedding {
2423    scale: f64,
2424    embedding: Embedding,
2425}
2426
2427impl ScaledEmbedding {
2428    pub fn new(scale: f64, embedding: Embedding) -> Self {
2429        Self { scale, embedding }
2430    }
2431
2432    pub fn embeddings(&self) -> &Tensor {
2433        self.embedding.embeddings()
2434    }
2435}
2436
2437impl Module for ScaledEmbedding {
2438    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2439        xs.apply(&self.embedding)? * self.scale
2440    }
2441}