mistralrs_core/layers.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6    quantized::{QMatMul, QTensor},
7    Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10    Conv2d, Conv2dConfig, Embedding, GroupNorm, LayerNorm, LayerNormConfig, Linear, Module,
11};
12use float8::F8E4M3;
13use half::{bf16, f16};
14use mistralrs_quant::{
15    AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
16    ShardedVarBuilder,
17};
18use serde::{Deserialize, Serialize};
19
20pub use crate::attention::Sdpa;
21pub use crate::layers_masker::CausalMasker;
22pub use crate::layers_utils::repeat_kv;
23use crate::{
24    amoe::{AnyMoeTrainableLayer, MlpLayer},
25    gguf::Content,
26    models::llama,
27    ops::SplitOp,
28    vision_models::{
29        gemma3::config::Gemma3TextConfig,
30        llama4,
31        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
32        phi4::Phi4MMConfig,
33    },
34};
35
36pub use mistralrs_quant::MatMul;
37
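/// Load an embedding table of shape `(in_size, out_size)` from `vb`.
///
/// With an AFQ `QuantizedConfig` the weights are stored as an AFQ linear layer and are
/// dequantized here; otherwise the weight tensor is read directly.
///
/// A minimal usage sketch (hypothetical `vb` and config field names, assumed from the caller):
/// ```ignore
/// let embed_tokens = embedding(
///     cfg.vocab_size,
///     cfg.hidden_size,
///     vb.pp("model.embed_tokens"),
///     &cfg.quantization_config,
/// )?;
/// ```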
38pub fn embedding(
39    in_size: usize,
40    out_size: usize,
41    vb: ShardedVarBuilder,
42    config: &Option<QuantizedConfig>,
43) -> Result<Embedding> {
    // AFQ quantization also applies to the embeddings: load the AFQ linear layer and dequantize it to recover the embedding weights.
45    let embeddings = if let Some(QuantizedConfig::Afq { .. }) = config {
46        let afq_layer =
47            AfqLayer::afq_linear_b(out_size, in_size, config.as_ref().unwrap(), false, vb)?;
48        afq_layer.dequantize_w()?
49    } else {
50        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
51    };
52    Ok(Embedding::new(embeddings, out_size))
53}
54
55pub fn layer_norm<C: Into<LayerNormConfig>>(
56    size: usize,
57    config: C,
58    vb: ShardedVarBuilder,
59) -> Result<LayerNorm> {
60    let config = config.into();
61    let weight = vb.get(size, "weight")?;
62    if config.affine {
63        let bias = vb.get(size, "bias")?;
64        Ok(LayerNorm::new(weight, bias, config.eps))
65    } else {
66        Ok(LayerNorm::new_no_bias(weight, config.eps))
67    }
68}
69
70pub fn group_norm(
71    num_groups: usize,
72    num_channels: usize,
73    eps: f64,
74    vb: ShardedVarBuilder,
75) -> Result<GroupNorm> {
76    let weight = vb.get(num_channels, "weight")?;
77    let bias = vb.get(num_channels, "bias")?;
78    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
79}
80
81pub fn conv2d(
82    in_channels: usize,
83    out_channels: usize,
84    kernel_size: usize,
85    cfg: Conv2dConfig,
86    vb: ShardedVarBuilder,
87) -> Result<Conv2d> {
88    let ws = vb.get(
89        (
90            out_channels,
91            in_channels / cfg.groups,
92            kernel_size,
93            kernel_size,
94        ),
95        "weight",
96    )?;
97    let bs = vb.get(out_channels, "bias")?;
98    Ok(Conv2d::new(ws, Some(bs), cfg))
99}
100
101pub fn conv2d_no_bias(
102    in_channels: usize,
103    out_channels: usize,
104    kernel_size: usize,
105    cfg: Conv2dConfig,
106    vb: ShardedVarBuilder,
107) -> Result<Conv2d> {
108    let ws = vb.get(
109        (
110            out_channels,
111            in_channels / cfg.groups,
112            kernel_size,
113            kernel_size,
114        ),
115        "weight",
116    )?;
117    Ok(Conv2d::new(ws, None, cfg))
118}
119
120pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
121    let ws = vb.get((out_dim, in_dim), "weight")?;
122    let bs = vb.get(out_dim, "bias")?;
123    Ok(Linear::new(ws, Some(bs)))
124}
125
126pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
127    let ws = vb.get((out_dim, in_dim), "weight")?;
128    Ok(Linear::new(ws, None))
129}
130
131pub fn linear_b(
132    in_dim: usize,
133    out_dim: usize,
134    bias: bool,
135    vb: ShardedVarBuilder,
136) -> Result<Linear> {
137    if bias {
138        linear(in_dim, out_dim, vb)
139    } else {
140        linear_no_bias(in_dim, out_dim, vb)
141    }
142}
143
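/// RMS normalization with a learned per-channel scale:
/// `y = x / sqrt(mean(x^2) + eps) * weight`, computed by `candle_nn::ops::rms_norm`.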
144#[derive(Debug, Clone)]
145pub struct RmsNorm {
146    eps: f64,
147    weight: Tensor,
148}
149
150impl RmsNorm {
151    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
152        let w = vb.get(size, "weight")?;
153        Ok(Self { eps, weight: w })
154    }
155
156    /// Gemma uses weight + 1.0
157    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
158        let w = vb.get(size, "weight")?;
159        let w = (w + 1.0)?;
160        Ok(Self { eps, weight: w })
161    }
162
163    /// Gemma uses weight + 1.0. Undo for UQFF generation.
164    pub fn undo_gemma(&self) -> Result<Self> {
165        Ok(Self {
166            eps: self.eps,
167            weight: (&self.weight - 1.0)?,
168        })
169    }
170
171    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
172        Ok(Self { eps, weight: w })
173    }
174
175    pub fn weight(&self) -> &Tensor {
176        &self.weight
177    }
178}
179
180impl Module for RmsNorm {
181    fn forward(&self, x: &Tensor) -> Result<Tensor> {
182        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
183    }
184}
185
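/// RMS norm variant that always upcasts the input to `f32` for the reduction, normalizes in
/// `f32`, casts back to the input dtype, and then applies the learned weight; this trades a
/// little speed for numerical stability.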
186#[derive(Debug, Clone)]
187pub struct F32RmsNorm {
188    w: Tensor,
189    eps: f64,
190}
191
192impl F32RmsNorm {
193    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
194        Ok(Self {
195            w: vb.get((size,), "weight")?,
196            eps,
197        })
198    }
199
200    pub fn weight(&self) -> &Tensor {
201        &self.w
202    }
203}
204
205impl Module for F32RmsNorm {
206    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
207        let initial_type = xs.dtype();
208        let mut xs = xs.to_dtype(DType::F32)?;
209        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
210        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
211        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
212    }
213}
214
215#[derive(Debug, Clone)]
216pub struct QRmsNorm {
217    eps: f64,
218    weight: Tensor,
219}
220
221impl QRmsNorm {
222    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
223        let scale = scale.dequantize(&scale.device())?;
224        Ok(Self {
225            eps: eps as f64,
226            weight: scale,
227        })
228    }
229
230    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
231        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
232    }
233}
234
235/// RoPE supporting LongRope
236#[derive(Debug, Clone)]
237pub struct PhiRotaryEmbedding {
238    short_sin: Tensor,
239    short_cos: Tensor,
240    long_cos: Option<Tensor>,
241    long_sin: Option<Tensor>,
242    original_max_position_embeddings: usize,
243}
244
245#[derive(Debug, Clone, Deserialize, Serialize)]
246#[serde(rename_all = "lowercase")]
247pub enum ScaledRopeType {
248    #[serde(alias = "su")]
249    #[serde(alias = "longrope")]
250    Su,
251    #[serde(alias = "yarn")]
252    Yarn,
253    #[serde(alias = "dynamic")]
254    Dynamic,
255    #[serde(alias = "linear")]
256    Linear,
257}
258
259impl FromStr for ScaledRopeType {
260    type Err = candle_core::Error;
261    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
262        match s {
263            "su" | "longrope" => Ok(Self::Su),
264            "yarn" => Ok(Self::Yarn),
265            "linear" => Ok(Self::Linear),
266            "dynamic" => Ok(Self::Dynamic),
267            _ => Err(candle_core::Error::Msg(
                "Unknown scaled RoPE type: expected one of `su`/`longrope`, `yarn`, `linear`, or `dynamic`.".to_string(),
269            )),
270        }
271    }
272}
273
274#[derive(Debug, Clone, Deserialize, Serialize)]
275#[serde(untagged)]
276pub enum PhiRopeScalingConfig {
277    Classic {
278        short_factor: Vec<f64>,
279        long_factor: Vec<f64>,
280        #[serde(rename = "type")]
281        scaling_type: ScaledRopeType,
282    },
283    Scaled {
284        short_factor: Vec<f64>,
285        long_factor: Vec<f64>,
286        #[serde(rename = "type")]
287        scaling_type: ScaledRopeType,
288        long_mscale: f64,
289        short_mscale: f64,
290    },
291}
292
293pub struct PhiRopeConfig {
294    pub rope_scaling: Option<PhiRopeScalingConfig>,
295    pub max_position_embeddings: usize,
296    pub original_max_position_embeddings: usize,
297    pub rope_theta: f64,
298    pub head_dim: usize,
299    pub partial_rotary_factor: Option<f64>,
300}
301
302impl PhiRotaryEmbedding {
303    fn new_classic_scaled(
304        short_factor: &[f64],
305        long_factor: &[f64],
306        scaling_type: &ScaledRopeType,
307        cfg: &PhiRopeConfig,
308        dtype: DType,
309        dev: &Device,
310    ) -> Result<Self> {
311        let max_seq_len = cfg.max_position_embeddings;
312        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
313
314        // Calculate scale
315        let scale =
316            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
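        // The attention scaling factor compensates for the extended context:
        //   su/longrope: sqrt(1 + ln(scale) / ln(original_max_position_embeddings))
        //   yarn:        0.1 * ln(scale) + 1
        // and is folded directly into the sin/cos tables below.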
317        let scaling_factor = if scale <= 1.0 {
318            1.0
319        } else {
320            match scaling_type {
321                ScaledRopeType::Su => {
322                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
323                }
324                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
325                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
326            }
327        };
328
329        // Calculate inv freqs for short, long
330        let inv_freq_long = (0..dim)
331            .step_by(2)
332            .enumerate()
333            .map(|(k, i)| {
334                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
335            })
336            .collect::<Vec<_>>();
337        let inv_freq_short = (0..dim)
338            .step_by(2)
339            .enumerate()
340            .map(|(k, i)| {
341                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
342            })
343            .collect::<Vec<_>>();
344        let inv_freq_len = inv_freq_long.len();
345
346        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
347            .to_dtype(DType::F32)?
348            .reshape((max_seq_len, 1))?;
349
350        // Calculate sin,cos for long
351        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
352        let freqs_long = t.matmul(&inv_freq_long)?;
353        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
354        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
355
356        // Calculate sin,cos for short
357        let inv_freq_short =
358            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
359        let freqs_short = t.matmul(&inv_freq_short)?;
360        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
361        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
362
363        Ok(Self {
364            short_cos,
365            short_sin,
366            long_cos: Some(long_cos),
367            long_sin: Some(long_sin),
368            original_max_position_embeddings: cfg.original_max_position_embeddings,
369        })
370    }
371
372    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
373        let max_seq_len = cfg.max_position_embeddings;
374        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
375
376        let inv_freq: Vec<_> = (0..dim)
377            .step_by(2)
378            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
379            .collect();
380        let inv_freq_len = inv_freq.len();
381        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
382        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
383            .to_dtype(DType::F32)?
384            .reshape((max_seq_len, 1))?;
385        let freqs = t.matmul(&inv_freq)?;
386        let sin = freqs.sin()?.to_dtype(dtype)?;
387        let cos = freqs.cos()?.to_dtype(dtype)?;
388        Ok(Self {
389            short_cos: cos,
390            short_sin: sin,
391            long_cos: None,
392            long_sin: None,
393            original_max_position_embeddings: cfg.original_max_position_embeddings,
394        })
395    }
396
397    #[allow(clippy::too_many_arguments)]
398    fn new_scaled(
399        short_factor: &[f64],
400        long_factor: &[f64],
401        scaling_type: &ScaledRopeType,
402        long_mscale: f64,
403        short_mscale: f64,
404        cfg: &PhiRopeConfig,
405        dtype: DType,
406        dev: &Device,
407    ) -> Result<Self> {
408        let max_seq_len = cfg.max_position_embeddings;
409        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
410
411        if !matches!(scaling_type, ScaledRopeType::Su) {
412            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
413        }
414
415        if short_factor.len() != dim / 2 {
416            candle_core::bail!(
417                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
418                short_factor.len(),
419                dim / 2
420            );
421        }
422        if long_factor.len() != dim / 2 {
423            candle_core::bail!(
424                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
425                long_factor.len(),
426                dim / 2
427            );
428        }
429
430        // Short cos/sin
431        let inv_freq_short: Vec<_> = (0..dim)
432            .step_by(2)
433            .enumerate()
434            .map(|(k, i)| {
435                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
436            })
437            .collect();
438        let inv_freq_len_short = inv_freq_short.len();
439        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
440        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
441            .to_dtype(DType::F32)?
442            .reshape((max_seq_len, 1))?;
443        let freqs_short = t_short.matmul(&inv_freq_short)?;
444        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
445        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
446
447        // Long cos/sin
448        let inv_freq_long: Vec<_> = (0..dim)
449            .step_by(2)
450            .enumerate()
451            .map(|(k, i)| {
452                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
453            })
454            .collect();
455        let inv_freq_len_long = inv_freq_long.len();
456        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
457        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
458            .to_dtype(DType::F32)?
459            .reshape((max_seq_len, 1))?;
460        let freqs_long = t_long.matmul(&inv_freq_long)?;
461        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
462        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
463        Ok(Self {
464            short_cos: cos_short,
465            short_sin: sin_short,
466            long_cos: Some(cos_long),
467            long_sin: Some(sin_long),
468            original_max_position_embeddings: cfg.original_max_position_embeddings,
469        })
470    }
471
472    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
473        let cfg: PhiRopeConfig = cfg.into();
474
475        match &cfg.rope_scaling {
476            Some(PhiRopeScalingConfig::Classic {
477                short_factor,
478                long_factor,
479                scaling_type,
480            }) => {
481                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
482            }
483
484            Some(PhiRopeScalingConfig::Scaled {
485                short_factor,
486                long_factor,
487                scaling_type,
488                long_mscale,
489                short_mscale,
490            }) => Self::new_scaled(
491                short_factor,
492                long_factor,
493                scaling_type,
494                *long_mscale,
495                *short_mscale,
496                &cfg,
497                dtype,
498                dev,
499            ),
500
501            None => Self::new_unscaled(&cfg, dtype, dev),
502        }
503    }
504
505    /// Returns (sin, cos) taking into account LongRope
506    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
507        if self.long_cos.is_none() {
508            return (&self.short_sin, &self.short_cos);
509        }
510        let seq_len = position_ids.iter().max().unwrap() + 1;
511        if seq_len > self.original_max_position_embeddings {
512            (
513                self.long_sin.as_ref().unwrap(),
514                self.long_cos.as_ref().unwrap(),
515            )
516        } else {
517            (&self.short_sin, &self.short_cos)
518        }
519    }
520
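    /// Apply rotary embeddings to `q` and `k` at the given per-sequence offsets.
    ///
    /// When the rotary dimension is smaller than the head dimension (Phi-3 / Phi-4 mini),
    /// only the first `rot_dim` features are rotated and the remainder is passed through.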
521    pub fn forward(
522        &self,
523        q: &Tensor,
524        k: &Tensor,
525        seqlen_offsets: &[usize],
526        position_ids: &[usize],
527    ) -> Result<(Tensor, Tensor)> {
528        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
529        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
530
531        let rot_dim = cos.dim(D::Minus1)? * 2;
532
        // Case for Phi 3 / Phi 4 mini: only the first `rot_dim` features are rotated,
        // the rest of the head dimension passes through unchanged.
        if rot_dim != q.dim(D::Minus1)? {
536            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
537            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
538            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
539            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
540
541            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
542                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
543                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
544                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
545                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
546                (q_embed, k_embed)
547            } else {
548                let mut q_embeds = Vec::new();
549                let mut k_embeds = Vec::new();
550                for (i, offset) in seqlen_offsets.iter().enumerate() {
551                    let cos = cos.narrow(0, *offset, seq_len)?;
552                    let sin = sin.narrow(0, *offset, seq_len)?;
553                    let q_embed = candle_nn::rotary_emb::rope(
554                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
555                        &cos,
556                        &sin,
557                    )?;
558                    let k_embed = candle_nn::rotary_emb::rope(
559                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
560                        &cos,
561                        &sin,
562                    )?;
563                    q_embeds.push(q_embed);
564                    k_embeds.push(k_embed);
565                }
566                let q_rot = Tensor::cat(&q_embeds, 0)?;
567                let k_rot = Tensor::cat(&k_embeds, 0)?;
568                (q_rot, k_rot)
569            };
570
571            Ok((
572                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
573                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
574            ))
575        } else if seqlen_offsets.len() == 1 {
576            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
577            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
578            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
579            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
580            Ok((q_embed, k_embed))
581        } else {
582            let mut q_embeds = Vec::new();
583            let mut k_embeds = Vec::new();
584            for (i, offset) in seqlen_offsets.iter().enumerate() {
585                let cos = cos.narrow(0, *offset, seq_len)?;
586                let sin = sin.narrow(0, *offset, seq_len)?;
587                let q_embed =
588                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
589                let k_embed =
590                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
591                q_embeds.push(q_embed);
592                k_embeds.push(k_embed);
593            }
594            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
595        }
596    }
597}
598
599/// RoPE for Llama3
600#[derive(Debug, Clone)]
601pub struct Llama3RotaryEmbedding(RotaryEmbedding);
602
603#[derive(Debug, Clone, Deserialize, Serialize, Default)]
604pub enum Llama3RopeType {
605    #[serde(rename = "llama3")]
606    Llama3,
607    #[default]
608    #[serde(rename = "default")]
609    Default,
610}
611
612#[derive(Debug, Clone, Deserialize, Serialize, Default)]
613pub struct Llama3RopeConfig {
614    pub factor: f32,
615    pub low_freq_factor: f32,
616    pub high_freq_factor: f32,
617    pub original_max_position_embeddings: usize,
618    pub rope_type: Llama3RopeType,
619}
620
621fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
622    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
623    (0..head_dim)
624        .step_by(2)
625        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
626        .collect()
627}
628
629fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
630    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
631    (0..head_dim)
632        .step_by(2)
633        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
634        .collect()
635}
636
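// Llama 3 style RoPE scaling: frequencies whose wavelength is shorter than
// `original_max_position_embeddings / high_freq_factor` are kept as-is, wavelengths longer than
// `original_max_position_embeddings / low_freq_factor` are divided by `factor`, and the band in
// between is linearly interpolated between the two via the `smooth` coefficient.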
637// https://github.com/huggingface/transformers/blob/1392a6867f40a55dfabaf306745c67627598b1af/src/transformers/modeling_rope_utils.py#L298
638impl Llama3RotaryEmbedding {
639    pub fn new_llama3(
640        dtype: DType,
641        cfg: &llama::Config,
642        dev: &Device,
643        is_gpt_neox: bool,
644    ) -> Result<Self> {
645        match &cfg.rope_scaling {
646            None
647            | Some(Llama3RopeConfig {
648                rope_type: Llama3RopeType::Default,
649                ..
650            }) => Ok(Self(RotaryEmbedding::new(
651                cfg.rope_theta,
652                cfg.hidden_size / cfg.num_attention_heads,
653                cfg.max_position_embeddings,
654                dev,
655                is_gpt_neox,
656                dtype,
657            )?)),
658            Some(rope_scaling) => {
659                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
660                    / rope_scaling.low_freq_factor;
661                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
662                    / rope_scaling.high_freq_factor;
663
664                let inv_freq = calculate_default_inv_freq(cfg)
665                    .into_iter()
666                    .map(|freq| {
667                        let wavelen = 2. * PI / freq;
668                        if wavelen < high_freq_wavelen {
669                            freq
670                        } else if wavelen > low_freq_wavelen {
671                            freq / rope_scaling.factor
672                        } else {
673                            let smooth = (rope_scaling.original_max_position_embeddings as f32
674                                / wavelen
675                                - rope_scaling.low_freq_factor)
676                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
677                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
678                        }
679                    })
680                    .collect::<Vec<_>>();
681                let inv_freq_len = inv_freq.len();
682                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
683
684                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
685                    .to_dtype(DType::F32)?
686                    .reshape((cfg.max_position_embeddings, 1))?;
687                let freqs = t.matmul(&inv_freq)?;
688                let sin = freqs.sin()?.to_dtype(dtype)?;
689                let cos = freqs.cos()?.to_dtype(dtype)?;
690                Ok(Self(RotaryEmbedding {
691                    sin,
692                    cos,
693                    is_gpt_neox,
694                }))
695            }
696        }
697    }
698
699    pub fn new_llama4(
700        dtype: DType,
701        cfg: &llama4::TextConfig,
702        dev: &Device,
703        is_gpt_neox: bool,
704    ) -> Result<Self> {
705        match &cfg.rope_scaling {
706            None
707            | Some(Llama3RopeConfig {
708                rope_type: Llama3RopeType::Default,
709                ..
710            }) => Ok(Self(RotaryEmbedding::new(
711                cfg.rope_theta,
712                cfg.hidden_size / cfg.num_attention_heads,
713                cfg.max_position_embeddings,
714                dev,
715                is_gpt_neox,
716                dtype,
717            )?)),
718            Some(rope_scaling) => {
719                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
720                    / rope_scaling.low_freq_factor;
721                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
722                    / rope_scaling.high_freq_factor;
723
724                let inv_freq = calculate_default_inv_freq_llama4(cfg)
725                    .into_iter()
726                    .map(|freq| {
727                        let wavelen = 2. * PI / freq;
728                        if wavelen < high_freq_wavelen {
729                            freq
730                        } else if wavelen > low_freq_wavelen {
731                            freq / rope_scaling.factor
732                        } else {
733                            let smooth = (rope_scaling.original_max_position_embeddings as f32
734                                / wavelen
735                                - rope_scaling.low_freq_factor)
736                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
737                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
738                        }
739                    })
740                    .collect::<Vec<_>>();
741                let inv_freq_len = inv_freq.len();
742                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
743
744                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
745                    .to_dtype(DType::F32)?
746                    .reshape((cfg.max_position_embeddings, 1))?;
747                let freqs = t.matmul(&inv_freq)?;
748                let sin = freqs.sin()?.to_dtype(dtype)?;
749                let cos = freqs.cos()?.to_dtype(dtype)?;
750                Ok(Self(RotaryEmbedding {
751                    sin,
752                    cos,
753                    is_gpt_neox,
754                }))
755            }
756        }
757    }
758
759    pub fn new_mllama3(
760        dtype: DType,
761        cfg: &MLlamaTextConfig,
762        dev: &Device,
763        is_gpt_neox: bool,
764    ) -> Result<Self> {
765        match &cfg.rope_scaling {
766            None
767            | Some(MLlamaRopeScaling {
768                rope_type: MLlamaRopeType::Default,
769                ..
770            }) => Ok(Self(RotaryEmbedding::new(
771                cfg.rope_theta,
772                cfg.hidden_size / cfg.num_attention_heads,
773                cfg.max_position_embeddings,
774                dev,
775                is_gpt_neox,
776                dtype,
777            )?)),
778            Some(MLlamaRopeScaling {
779                rope_type: MLlamaRopeType::Llama3,
780                original_max_position_embeddings,
781                factor,
782                attention_factor: _,
783                beta_fast: _,
784                beta_slow: _,
785                short_factor: _,
786                long_factor: _,
787                low_freq_factor,
788                high_freq_factor,
789            }) => {
790                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
791                let low_freq_factor = low_freq_factor
792                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
793                let high_freq_factor = high_freq_factor
794                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
795
796                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
797                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
798
799                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
800
801                let inv_freq = (0..head_dim)
802                    .step_by(2)
803                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
804                    .map(|freq| {
805                        let wavelen = 2. * PI / freq;
806                        if wavelen < high_freq_wavelen {
807                            freq
808                        } else if wavelen > low_freq_wavelen {
809                            freq / factor
810                        } else {
811                            let smooth = (*original_max_position_embeddings as f32 / wavelen
812                                - low_freq_factor)
813                                / (high_freq_factor - low_freq_factor);
814                            (1. - smooth) * freq / factor + smooth * freq
815                        }
816                    })
817                    .collect::<Vec<_>>();
818                let inv_freq_len = inv_freq.len();
819                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
820
821                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
822                    .to_dtype(DType::F32)?
823                    .reshape((cfg.max_position_embeddings, 1))?;
824                let freqs = t.matmul(&inv_freq)?;
825                let sin = freqs.sin()?.to_dtype(dtype)?;
826                let cos = freqs.cos()?.to_dtype(dtype)?;
827                Ok(Self(RotaryEmbedding {
828                    sin,
829                    cos,
830                    is_gpt_neox,
831                }))
832            }
833            Some(MLlamaRopeScaling {
834                rope_type: other, ..
835            }) => {
836                candle_core::bail!(
837                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
838                )
839            }
840        }
841    }
842
843    pub fn forward(
844        &self,
845        q: &Tensor,
846        k: &Tensor,
847        seqlen_offsets: &[usize],
848    ) -> Result<(Tensor, Tensor)> {
849        self.0.forward(q, k, seqlen_offsets)
850    }
851}
852
853// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
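// Multimodal RoPE (MRoPE): the position ids carry three streams (temporal, height, width);
// `mrope_section` splits the rotary dimensions into sections, and section `i` takes its
// cos/sin from stream `i % 3` before the sections are concatenated back together.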
854#[derive(Debug, Clone)]
855pub struct Qwen2VLRotaryEmbedding {
856    inv_freq: Tensor,
857    mrope_section: Vec<usize>,
858}
859
860impl Qwen2VLRotaryEmbedding {
861    pub fn new(
862        base: f32,
863        head_dim: usize,
864        device: &Device,
865        mrope_section: Vec<usize>,
866    ) -> Result<Self> {
867        let inv_freq: Vec<_> = (0..head_dim)
868            .step_by(2)
869            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
870            .collect();
871        let inv_freq_len = inv_freq.len();
872        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
873        Ok(Self {
874            inv_freq,
875            mrope_section,
876        })
877    }
878
    /// Compute the rotary (cos, sin) tables for the three MRoPE position-id streams (temporal, height, width); returns `(cos, sin)`.
880    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
881        let inv_freq_expanded =
882            self.inv_freq
883                .reshape((1, 1, (), 1))?
884                .repeat((3, position_ids.dim(1)?, 1, 1))?;
885        let position_ids_expanded = position_ids.unsqueeze(2)?;
886        let freqs = inv_freq_expanded
887            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
888            .transpose(2, 3)?;
889        let cos = freqs.cos()?;
890        let sin = freqs.sin()?;
891
892        let cos = Tensor::cat(
893            &cos.split(&self.mrope_section, D::Minus1)?
894                .into_iter()
895                .enumerate()
896                .map(|(i, m)| m.i(i % 3))
897                .collect::<Result<Vec<_>>>()?,
898            D::Minus1,
899        )?
900        .squeeze(0)?
901        .to_dtype(dtype)?
902        .contiguous()?;
903        let sin = Tensor::cat(
904            &sin.split(&self.mrope_section, D::Minus1)?
905                .into_iter()
906                .enumerate()
907                .map(|(i, m)| m.i(i % 3))
908                .collect::<Result<Vec<_>>>()?,
909            D::Minus1,
910        )?
911        .squeeze(0)?
912        .to_dtype(dtype)?
913        .contiguous()?;
914
915        Ok((cos, sin))
916    }
917
918    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L203
919    pub fn forward(
920        &self,
921        (cos, sin): &(Tensor, Tensor),
922        q: &mut Tensor,
923        k: &mut Tensor,
924    ) -> Result<()> {
925        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
926        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
927        Ok(())
928    }
929}
930
931#[derive(Debug, Clone)]
932pub struct Qwen2_5VLRotaryEmbedding {
933    inv_freq: Tensor,
934    mrope_section: Vec<usize>,
935}
936
937impl Qwen2_5VLRotaryEmbedding {
938    pub fn new(
939        base: f32,
940        head_dim: usize,
941        device: &Device,
942        mrope_section: Vec<usize>,
943    ) -> Result<Self> {
944        let inv_freq: Vec<_> = (0..head_dim)
945            .step_by(2)
946            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
947            .collect();
948        let inv_freq_len = inv_freq.len();
949        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
950        Ok(Self {
951            inv_freq,
952            mrope_section,
953        })
954    }
955
    /// Compute the rotary (cos, sin) tables for the three MRoPE position-id streams (temporal, height, width); returns `(cos, sin)`.
957    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
958        let inv_freq_expanded =
959            self.inv_freq
960                .reshape((1, 1, (), 1))?
961                .repeat((3, position_ids.dim(1)?, 1, 1))?;
962        let position_ids_expanded = position_ids.unsqueeze(2)?;
963        let freqs = inv_freq_expanded
964            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
965            .transpose(2, 3)?;
966        let cos = freqs.cos()?;
967        let sin = freqs.sin()?;
968
969        let cos = Tensor::cat(
970            &cos.split(&self.mrope_section, D::Minus1)?
971                .into_iter()
972                .enumerate()
973                .map(|(i, m)| m.i(i % 3))
974                .collect::<Result<Vec<_>>>()?,
975            D::Minus1,
976        )?
977        .squeeze(0)?
978        .to_dtype(dtype)?
979        .contiguous()?;
980        let sin = Tensor::cat(
981            &sin.split(&self.mrope_section, D::Minus1)?
982                .into_iter()
983                .enumerate()
984                .map(|(i, m)| m.i(i % 3))
985                .collect::<Result<Vec<_>>>()?,
986            D::Minus1,
987        )?
988        .squeeze(0)?
989        .to_dtype(dtype)?
990        .contiguous()?;
991
992        Ok((cos, sin))
993    }
994
995    pub fn forward(
996        &self,
997        (cos, sin): &(Tensor, Tensor),
998        q: &mut Tensor,
999        k: &mut Tensor,
1000    ) -> Result<()> {
1001        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
1002        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
1003        Ok(())
1004    }
1005}
1006
1007#[derive(Debug, Clone)]
1008pub struct DeepSeekV2RotaryEmbedding {
1009    sin: Tensor,
1010    cos: Tensor,
1011}
1012
1013#[derive(Debug, Clone, Deserialize, Serialize)]
1014#[serde(untagged)]
1015pub enum DeepSeekV2RopeScaling {
1016    Yarn {
1017        original_max_position_embeddings: usize,
1018        beta_fast: f32,
1019        beta_slow: f32,
1020        mscale: f32,
1021        mscale_all_dim: f32,
1022        factor: f32,
1023        #[serde(rename = "type")]
1024        scaling_type: ScaledRopeType,
1025    },
1026    LinearOrDynamic {
1027        #[serde(rename = "type")]
1028        scaling_type: ScaledRopeType,
1029        factor: f64,
1030    },
1031}
1032
1033pub struct DeepSeekV2RopeConfig {
1034    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
1035    pub max_position_embeddings: usize,
1036    pub rope_theta: f32,
1037    pub qk_rope_head_dim: usize,
1038}
1039
1040impl DeepSeekV2RotaryEmbedding {
1041    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1042        let max_seq_len = cfg.max_position_embeddings;
1043        let dim = cfg.qk_rope_head_dim;
1044
1045        let inv_freq: Vec<_> = (0..dim)
1046            .step_by(2)
1047            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
1048            .collect();
1049        let inv_freq_len = inv_freq.len();
1050        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1051        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1052            .to_dtype(DType::F32)?
1053            .reshape((max_seq_len, 1))?;
1054        let freqs = t.matmul(&inv_freq)?;
1055
1056        let sin = freqs.sin()?.to_dtype(dtype)?;
1057        let cos = freqs.cos()?.to_dtype(dtype)?;
1058
1059        Ok(Self { sin, cos })
1060    }
1061
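    // YaRN: find the dimension index whose rotation period equals
    // `max_position_embeddings / num_rot` tokens, i.e. solve
    // base^(2d/dim) = max_position_embeddings / (num_rot * 2 * pi) for d.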
1062    fn yarn_find_correction_dim(
1063        num_rot: f32,
1064        dim: usize,
1065        base: f32,
1066        max_position_embeddings: usize,
1067    ) -> f32 {
1068        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1069            / (2. * base.ln())
1070    }
1071
1072    fn yarn_find_correction_range(
1073        low_rot: f32,
1074        high_rot: f32,
1075        dim: usize,
1076        base: f32,
1077        max_position_embeddings: usize,
1078    ) -> (f32, f32) {
1079        let low =
1080            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1081        let high =
1082            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1083        (low.max(0.), high.min(dim as f32 - 1.))
1084    }
1085
1086    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1087        if min == max {
1088            // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255
1089            max += 0.001;
1090        }
1091        let linear_func =
1092            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1093        linear_func.clamp(0., 1)
1094    }
1095
1096    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1097        if scale <= 1. {
1098            return 1.;
1099        }
1100        0.1 * mscale * scale.ln() + 1.
1101    }
1102
1103    #[allow(clippy::too_many_arguments)]
1104    fn new_yarn(
1105        cfg: &DeepSeekV2RopeConfig,
1106        dtype: DType,
1107        dev: &Device,
1108        original_max_position_embeddings: usize,
1109        beta_fast: f32,
1110        beta_slow: f32,
1111        factor: f32,
1112        mscale: f32,
1113        mscale_all_dim: f32,
1114    ) -> Result<Self> {
1115        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1116            .step_by(2)
1117            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1118            .collect();
1119        let freq_extra_len = freq_extra.len();
1120        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1121        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1122            .step_by(2)
1123            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1124            .collect();
1125        let freq_inter_len = freq_inter.len();
1126        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1127
1128        let (low, high) = Self::yarn_find_correction_range(
1129            beta_fast,
1130            beta_slow,
1131            cfg.qk_rope_head_dim,
1132            cfg.rope_theta,
1133            original_max_position_embeddings,
1134        );
1135        let inv_freq_mask =
1136            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1137        let inv_freq = freq_inter
1138            .broadcast_mul(&(1. - &inv_freq_mask)?)?
1139            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1140
1141        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1142            .to_dtype(DType::F32)?
1143            .reshape((cfg.max_position_embeddings, 1))?;
1144        let freqs = t.matmul(&inv_freq)?;
1145
1146        let mscale =
1147            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1148        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1149        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1150
1151        Ok(Self { sin, cos })
1152    }
1153
1154    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1155        match &cfg.rope_scaling {
1156            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1157                scaling_type: _,
1158                factor: _,
1159            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1160            Some(DeepSeekV2RopeScaling::Yarn {
1161                original_max_position_embeddings,
1162                beta_fast,
1163                beta_slow,
1164                factor,
1165                mscale,
1166                mscale_all_dim,
1167                scaling_type: _,
1168            }) => Self::new_yarn(
1169                cfg,
1170                dtype,
1171                dev,
1172                *original_max_position_embeddings,
1173                *beta_fast,
1174                *beta_slow,
1175                *factor,
1176                *mscale,
1177                *mscale_all_dim,
1178            ),
1179            None => Self::new_unscaled(cfg, dtype, dev),
1180        }
1181    }
1182
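    /// Note: DeepSeek applies the interleaved rotary variant (`rope_i`), pairing adjacent
    /// features, rather than the half-rotation (`rope`) used elsewhere in this file.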
1183    pub fn forward(
1184        &self,
1185        q: &Tensor,
1186        k: &Tensor,
1187        seqlen_offsets: &[usize],
1188    ) -> Result<(Tensor, Tensor)> {
1189        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1190
1191        if seqlen_offsets.len() == 1 {
1192            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1193            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1194            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1195            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1196            Ok((q_embed, k_embed))
1197        } else {
1198            let mut q_embeds = Vec::new();
1199            let mut k_embeds = Vec::new();
1200            for (i, offset) in seqlen_offsets.iter().enumerate() {
1201                let cos = self.cos.narrow(0, *offset, seq_len)?;
1202                let sin = self.sin.narrow(0, *offset, seq_len)?;
1203                let q_embed = candle_nn::rotary_emb::rope_i(
1204                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
1205                    &cos,
1206                    &sin,
1207                )?;
1208                let k_embed = candle_nn::rotary_emb::rope_i(
1209                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
1210                    &cos,
1211                    &sin,
1212                )?;
1213                q_embeds.push(q_embed);
1214                k_embeds.push(k_embed);
1215            }
1216            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1217        }
1218    }
1219}
1220
1221#[derive(Debug, Clone)]
1222pub struct Phi4MMRotaryEmbedding {
1223    short_sin: Tensor,
1224    short_cos: Tensor,
1225    long_cos: Option<Tensor>,
1226    long_sin: Option<Tensor>,
1227    original_max_position_embeddings: usize,
1228}
1229
1230#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1231#[serde(rename_all = "lowercase")]
1232pub enum Phi4MMScaledRopeType {
1233    #[serde(alias = "longrope")]
1234    LongRope,
1235    #[default]
1236    Default,
1237}
1238
1239#[derive(Debug, Clone, Deserialize, Serialize)]
1240pub struct Phi4MMRopeScalingConfig {
1241    short_factor: Option<Vec<f64>>,
1242    long_factor: Option<Vec<f64>>,
1243    #[serde(rename = "type")]
1244    scaling_type: Phi4MMScaledRopeType,
1245}
1246
1247impl Phi4MMRotaryEmbedding {
1248    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1249        let max_seq_len = cfg.max_position_embeddings;
1250        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1251
1252        let inv_freq: Vec<_> = (0..dim)
1253            .step_by(2)
1254            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1255            .collect();
1256        let inv_freq_len = inv_freq.len();
1257        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1258        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1259            .to_dtype(DType::F32)?
1260            .reshape((max_seq_len, 1))?;
1261        let freqs = t.matmul(&inv_freq)?;
1262        let sin = freqs.sin()?.to_dtype(dtype)?;
1263        let cos = freqs.cos()?.to_dtype(dtype)?;
1264        Ok(Self {
1265            short_cos: cos,
1266            short_sin: sin,
1267            long_cos: None,
1268            long_sin: None,
1269            original_max_position_embeddings: cfg.original_max_position_embeddings,
1270        })
1271    }
1272
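    /// LongRope: separate short/long factor tables are used below/above the original context
    /// length, and both sin/cos tables are scaled by sqrt(1 + ln(s) / ln(original_max)) where
    /// `s = max_position_embeddings / original_max_position_embeddings`.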
1273    #[allow(clippy::too_many_arguments)]
1274    fn new_longrope(
1275        short_factor: &[f64],
1276        long_factor: &[f64],
1277        cfg: &Phi4MMConfig,
1278        dtype: DType,
1279        dev: &Device,
1280    ) -> Result<Self> {
1281        let max_seq_len = cfg.max_position_embeddings;
1282        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1283
1284        // Calculate scale
1285        let scale =
1286            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1287        let scaling_factor = if scale <= 1.0 {
1288            1.0
1289        } else {
1290            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1291        };
1292
1293        // Short cos/sin
1294        let inv_freq_short: Vec<_> = (0..dim)
1295            .step_by(2)
1296            .enumerate()
1297            .map(|(k, i)| {
1298                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1299            })
1300            .collect();
1301        let inv_freq_len_short = inv_freq_short.len();
1302        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1303        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1304            .to_dtype(DType::F32)?
1305            .reshape((max_seq_len, 1))?;
1306        let freqs_short = t_short.matmul(&inv_freq_short)?;
1307        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1308        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1309
1310        // Long cos/sin
1311        let inv_freq_long: Vec<_> = (0..dim)
1312            .step_by(2)
1313            .enumerate()
1314            .map(|(k, i)| {
1315                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1316            })
1317            .collect();
1318        let inv_freq_len_long = inv_freq_long.len();
1319        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1320        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1321            .to_dtype(DType::F32)?
1322            .reshape((max_seq_len, 1))?;
1323        let freqs_long = t_long.matmul(&inv_freq_long)?;
1324        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1325        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1326
1327        Ok(Self {
1328            short_cos: cos_short,
1329            short_sin: sin_short,
1330            long_cos: Some(cos_long),
1331            long_sin: Some(sin_long),
1332            original_max_position_embeddings: cfg.original_max_position_embeddings,
1333        })
1334    }
1335
1336    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1337        match &cfg.rope_scaling {
1338            Some(Phi4MMRopeScalingConfig {
1339                scaling_type: Phi4MMScaledRopeType::LongRope,
1340                short_factor: Some(short_factor),
1341                long_factor: Some(long_factor),
1342            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1343
1344            _ => Self::new_unscaled(cfg, dtype, dev),
1345        }
1346    }
1347
1348    /// Returns (sin, cos) taking into account LongRope
1349    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1350        if self.long_cos.is_none() {
1351            return (&self.short_sin, &self.short_cos);
1352        }
1353        let seq_len = position_ids.iter().max().unwrap() + 1;
1354        if seq_len > self.original_max_position_embeddings {
1355            (
1356                self.long_sin.as_ref().unwrap(),
1357                self.long_cos.as_ref().unwrap(),
1358            )
1359        } else {
1360            (&self.short_sin, &self.short_cos)
1361        }
1362    }
1363
1364    pub fn forward(
1365        &self,
1366        q: &Tensor,
1367        k: &Tensor,
1368        seqlen_offsets: &[usize],
1369        position_ids: &[usize],
1370    ) -> Result<(Tensor, Tensor)> {
1371        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1372        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1373
1374        let rot_dim = cos.dim(D::Minus1)? * 2;
1375        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1376        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1377        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1378        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1379
1380        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1381            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1382            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1383            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1384            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1385            (q_embed, k_embed)
1386        } else {
1387            let mut q_embeds = Vec::new();
1388            let mut k_embeds = Vec::new();
1389            for (i, offset) in seqlen_offsets.iter().enumerate() {
1390                let cos = cos.narrow(0, *offset, seq_len)?;
1391                let sin = sin.narrow(0, *offset, seq_len)?;
1392                let q_embed = candle_nn::rotary_emb::rope(
1393                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1394                    &cos,
1395                    &sin,
1396                )?;
1397                let k_embed = candle_nn::rotary_emb::rope(
1398                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1399                    &cos,
1400                    &sin,
1401                )?;
1402                q_embeds.push(q_embed);
1403                k_embeds.push(k_embed);
1404            }
1405            let q_rot = Tensor::cat(&q_embeds, 0)?;
1406            let k_rot = Tensor::cat(&k_embeds, 0)?;
1407            (q_rot, k_rot)
1408        };
1409
1410        Ok((
1411            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1412            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1413        ))
1414    }
1415}
1416
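/// RoPE for Gemma 3. Linear scaling simply divides the inverse frequencies by the configured
/// factor; without a scaling config a factor of 1.0 is used.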
1417#[derive(Debug, Clone)]
1418pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1419
1420#[derive(Debug, Clone, Deserialize, Serialize)]
1421#[serde(rename_all = "lowercase")]
1422pub enum Gemmma3ScaledRopeType {
1423    #[serde(alias = "linear")]
1424    Linear,
1425}
1426
1427#[derive(Debug, Clone, Deserialize, Serialize)]
1428pub struct Gemma3RopeScalingConfig {
1429    factor: f64,
1430    rope_type: Gemmma3ScaledRopeType,
1431}
1432
1433impl Gemma3RotaryEmbedding {
1434    fn new_linear(
1435        cfg: &Gemma3TextConfig,
1436        factor: f64,
1437        is_gpt_neox: bool,
1438        dtype: DType,
1439        dev: &Device,
1440    ) -> Result<Self> {
1441        let max_seq_len = cfg.max_position_embeddings;
1442        let dim = cfg.head_dim;
1443
1444        let inv_freq: Vec<_> = (0..dim)
1445            .step_by(2)
1446            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1447            .collect();
1448        let inv_freq_len = inv_freq.len();
1449        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1450        let inv_freq = (inv_freq / factor)?;
1451
1452        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1453            .to_dtype(DType::F32)?
1454            .reshape((max_seq_len, 1))?;
1455        let freqs = t.matmul(&inv_freq)?;
1456        let sin = freqs.sin()?.to_dtype(dtype)?;
1457        let cos = freqs.cos()?.to_dtype(dtype)?;
1458        Ok(Self(RotaryEmbedding {
1459            cos,
1460            sin,
1461            is_gpt_neox,
1462        }))
1463    }
1464
1465    pub fn new(
1466        is_gpt_neox: bool,
1467        dtype: DType,
1468        cfg: &Gemma3TextConfig,
1469        dev: &Device,
1470    ) -> Result<Self> {
1471        match &cfg.rope_scaling {
1472            Some(Gemma3RopeScalingConfig {
1473                rope_type: Gemmma3ScaledRopeType::Linear,
1474                factor,
1475            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1476
1477            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1478        }
1479    }
1480
1481    pub fn forward(
1482        &self,
1483        q: &Tensor,
1484        k: &Tensor,
1485        seqlen_offsets: &[usize],
1486    ) -> Result<(Tensor, Tensor)> {
1487        self.0.forward(q, k, seqlen_offsets)
1488    }
1489}
1490
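/// Linear layer backed by a `QMatMul` (quantized or plain tensor) with an optional bias.
/// When the weight is quantized, the input is cast to `f32` for the matmul and the output is
/// cast back to the stored dtype.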
1491#[derive(Debug, Clone)]
1492pub struct QLinear {
1493    inner: QMatMul,
1494    bias: Option<Tensor>,
1495    dtype: DType,
1496}
1497
1498impl QLinear {
1499    pub fn new<R: std::io::Read + std::io::Seek>(
1500        ct: &mut Content<'_, R>,
1501        name: &str,
1502        device: &Device,
1503    ) -> Result<Self> {
1504        let w = ct.tensor(&format!("{name}.weight"), device)?;
1505        let b = ct.tensor(&format!("{name}.bias"), device)?;
1506        let inner = QMatMul::from_qtensor(w)?;
1507        let bias = b.dequantize(device)?;
1508        Ok(Self {
1509            inner,
1510            bias: Some(bias),
1511            dtype: DType::F32,
1512        })
1513    }
1514
1515    pub fn from_linear(linear: Linear) -> Self {
1516        Self {
1517            inner: QMatMul::Tensor(linear.weight().clone()),
1518            bias: linear.bias().cloned(),
1519            dtype: linear.weight().dtype(),
1520        }
1521    }
1522
1523    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1524        let dtype = w.dtype();
1525        Self {
1526            inner: QMatMul::Tensor(w),
1527            bias: b,
1528            dtype,
1529        }
1530    }
1531
1532    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1533        if let Some(ref b) = b {
1534            assert_eq!(b.dtype(), DType::F32);
1535        }
1536        Self {
1537            inner: QMatMul::QTensor(Arc::new(w)),
1538            bias: b,
1539            dtype: DType::F32,
1540        }
1541    }
1542
1543    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1544        Self {
1545            inner,
1546            bias: old.bias.clone(),
1547            dtype: old.dtype,
1548        }
1549    }
1550
1551    pub fn inner(&mut self) -> &mut QMatMul {
1552        &mut self.inner
1553    }
1554
1555    pub fn inner_ref(&self) -> &QMatMul {
1556        &self.inner
1557    }
1558
1559    pub fn is_quant(&self) -> bool {
1560        matches!(self.inner, QMatMul::QTensor(_))
1561    }
1562
1563    pub fn bias(&self) -> Option<&Tensor> {
1564        self.bias.as_ref()
1565    }
1566
1567    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1568        self.bias.as_mut()
1569    }
1570}
1571
1572impl Module for QLinear {
1573    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1574        let xs = if self.is_quant() {
1575            xs.to_dtype(DType::F32)?
1576        } else {
1577            xs.clone()
1578        };
1579        if let Some(bias) = &self.bias {
1580            self.inner
1581                .forward(&xs)?
1582                .broadcast_add(bias)?
1583                .to_dtype(self.dtype)
1584        } else {
1585            self.inner.forward(&xs)?.to_dtype(self.dtype)
1586        }
1587    }
1588}
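
// A minimal usage sketch for `QLinear` (hypothetical helper, not referenced elsewhere): build a
// layer from unquantized parts and run the dtype-preserving forward pass defined above.
#[allow(dead_code)]
fn qlinear_usage_sketch(weight: Tensor, bias: Option<Tensor>, xs: &Tensor) -> Result<Tensor> {
    // Quantized weights are evaluated in F32 and cast back; plain tensors keep their own dtype.
    QLinear::from_parts(weight, bias).forward(xs)
}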
1589
1590#[derive(Debug, Clone)]
1591pub struct RotaryEmbedding {
1592    cos: Tensor,
1593    sin: Tensor,
1594    is_gpt_neox: bool,
1595}
1596
1597impl RotaryEmbedding {
1598    pub fn new(
1599        base: f32,
1600        head_dim: usize,
1601        max_position_embeddings: usize,
1602        device: &Device,
1603        is_gpt_neox: bool,
1604        dtype: DType,
1605    ) -> Result<Self> {
1606        let inv_freq: Vec<_> = (0..head_dim)
1607            .step_by(2)
1608            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1609            .collect();
1610        let inv_freq_len = inv_freq.len();
1611        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1612        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1613            .to_dtype(DType::F32)?
1614            .reshape((max_position_embeddings, 1))?;
1615        let freqs = t.matmul(&inv_freq)?;
1616        let sin = freqs.sin()?.to_dtype(dtype)?;
1617        let cos = freqs.cos()?.to_dtype(dtype)?;
1618
1619        Ok(Self {
1620            cos,
1621            sin,
1622            is_gpt_neox,
1623        })
1624    }
1625
1626    pub fn new_partial(
1627        base: f32,
1628        rot_dim: usize,
1629        max_position_embeddings: usize,
1630        device: &Device,
1631        is_gpt_neox: bool,
1632        dtype: DType,
1633    ) -> Result<Self> {
1634        let inv_freq: Vec<_> = (0..rot_dim)
1635            .step_by(2)
1636            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
1637            .collect();
1638        let inv_freq_len = inv_freq.len();
1639        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1640        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1641            .to_dtype(DType::F32)?
1642            .reshape((max_position_embeddings, 1))?;
1643        let freqs = t.matmul(&inv_freq)?;
1644        let sin = freqs.sin()?.to_dtype(dtype)?;
1645        let cos = freqs.cos()?.to_dtype(dtype)?;
1646
1647        Ok(Self {
1648            cos,
1649            sin,
1650            is_gpt_neox,
1651        })
1652    }
1653
1654    pub fn forward(
1655        &self,
1656        q: &Tensor,
1657        k: &Tensor,
1658        seqlen_offsets: &[usize],
1659    ) -> Result<(Tensor, Tensor)> {
1660        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
1661        let (_b_sz, kh, _seq_len, _n_embd) = k.dims4()?;
1662
1663        let rope = if self.is_gpt_neox {
1664            candle_nn::rotary_emb::rope
1665        } else {
1666            candle_nn::rotary_emb::rope_i
1667        };
1668
1669        if cfg!(feature = "cuda") && qh == kh {
1670            let (cos, sin) = if seqlen_offsets.len() == 1 {
1671                (
1672                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
1673                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
1674                )
1675            } else {
1676                let mut cos_s = Vec::new();
1677                let mut sin_s = Vec::new();
1678                for offset in seqlen_offsets {
1679                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
1680                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
1681                }
1682                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
1683            };
1684
1685            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
1686            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
1687            mistralrs_quant::rotary::apply_rotary_inplace(
1688                &q_embed,
1689                &k_embed,
1690                &cos,
1691                &sin,
1692                self.is_gpt_neox,
1693            )?;
1694            let mut q = q_embed
1695                .reshape((b_sz, seq_len, qh, n_embd))?
1696                .transpose(1, 2)?;
1697            let mut k = k_embed
1698                .reshape((b_sz, seq_len, kh, n_embd))?
1699                .transpose(1, 2)?;
1700            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
1701                q = q.contiguous()?;
1702                k = k.contiguous()?;
1703            }
1704            Ok((q, k))
1705        } else if seqlen_offsets.len() == 1 {
1706            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1707            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1708            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1709            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
1710            Ok((q_embed, k_embed))
1711        } else {
1712            let mut q_embeds = Vec::new();
1713            let mut k_embeds = Vec::new();
1714            for (i, offset) in seqlen_offsets.iter().enumerate() {
1715                let cos = self.cos.narrow(0, *offset, seq_len)?;
1716                let sin = self.sin.narrow(0, *offset, seq_len)?;
1717                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1718                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1719                q_embeds.push(q_embed);
1720                k_embeds.push(k_embed);
1721            }
1722            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1723        }
1724    }
1725}
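
// A hedged usage sketch for `RotaryEmbedding` (hypothetical shapes and hyperparameters): q and k
// are expected as (batch, num_heads, seq_len, head_dim), with one offset per sequence in the batch.
#[allow(dead_code)]
fn rope_usage_sketch(q: &Tensor, k: &Tensor, dev: &Device) -> Result<(Tensor, Tensor)> {
    // base = 10_000, head_dim = 64, max positions = 4096, GPT-NeoX (non-interleaved) layout.
    let rope = RotaryEmbedding::new(10_000.0, 64, 4096, dev, true, DType::F32)?;
    rope.forward(q, k, &[0])
}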
1726
1727#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
1728#[serde(rename_all = "lowercase")]
1729pub enum Activation {
1730    #[default]
1731    #[serde(alias = "gelu")]
1732    Gelu,
1733    #[serde(alias = "gelu_new")]
1734    NewGelu,
1735    Relu,
1736    Relu2,
1737    Relu6,
1738    Silu,
1739    Sigmoid,
1740    HardSigmoid,
1741    Swiglu,
1742    Swish,
1743    HardSwish,
1744    Elu(f64),
1745    LeakyRelu(f64),
1746    #[serde(alias = "gelu_pytorch_tanh")]
1747    GeluPytorchTanh,
1748    QuickGelu,
1749}
1750
1751impl Module for Activation {
1752    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1753        match self {
1754            Self::Gelu => xs.gelu_erf(),
1755            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
1756            Self::NewGelu => xs.gelu(),
1757            Self::Relu => xs.relu(),
1758            Self::Relu2 => xs.relu()?.sqr(),
1759            Self::Relu6 => xs.clamp(0f32, 6f32),
1760            Self::Silu => xs.silu(),
1761            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
1762            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
1763            Self::Swiglu => candle_nn::ops::swiglu(xs),
1764            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
1765            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
1766            &Self::Elu(alpha) => xs.elu(alpha),
1767            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
1768            Self::GeluPytorchTanh => xs.gelu(),
1769            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
1770        }
1771    }
1772}
1773
1774impl TryInto<candle_nn::Activation> for Activation {
1775    type Error = candle_core::Error;
1776
1777    fn try_into(self) -> Result<candle_nn::Activation> {
1778        match self {
1779            Self::Gelu => Ok(candle_nn::Activation::Gelu),
1780            Self::Relu => Ok(candle_nn::Activation::Relu),
1781            Self::Silu => Ok(candle_nn::Activation::Silu),
1782            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
1783            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
1784            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
1785            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
1786            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
1787            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
1788            Self::Swish => Ok(candle_nn::Activation::Swish),
1789            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
1790            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
1791            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
1792            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
1793            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
1794        }
1795    }
1796}
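
// Sketch of how the activation is typically consumed (hypothetical helper): configs usually carry
// the activation as a string such as "silu" or "gelu_pytorch_tanh", the serde aliases above map it
// onto a variant, and the `Module` impl applies it. The `TryInto<candle_nn::Activation>` conversion
// only succeeds for variants with a direct counterpart (QuickGelu, for one, bails).
#[allow(dead_code)]
fn activation_usage_sketch(xs: &Tensor) -> Result<Tensor> {
    Activation::Silu.forward(xs)
}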
1797
1798#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1799pub struct Conv3dConfig {
1800    pub padding: usize,
1801    pub stride: usize,
1802    pub dilation: usize,
1803    pub groups: usize,
1804}
1805
1806impl Default for Conv3dConfig {
1807    fn default() -> Self {
1808        Self {
1809            padding: 0,
1810            stride: 1,
1811            dilation: 1,
1812            groups: 1,
1813        }
1814    }
1815}
1816
1817pub struct Conv3dNoBias {
1818    conv2d_1: Conv2d,
1819    conv2d_2: Conv2d,
1820}
1821
1822impl Conv3dNoBias {
1823    pub fn new(
1824        in_channels: usize,
1825        out_channels: usize,
1826        kernel_sizes: [usize; 3],
1827        cfg: Conv3dConfig,
1828        vb: ShardedVarBuilder,
1829    ) -> Result<Self> {
1830        let ws = vb.get(
1831            (
1832                out_channels,
1833                in_channels / cfg.groups,
1834                kernel_sizes[0],
1835                kernel_sizes[1],
1836                kernel_sizes[2],
1837            ),
1838            "weight",
1839        )?;
1840
1841        // Split on temporal dimension
1842        // https://github.com/pytorch/pytorch/issues/139066
1843
1844        let w1 = ws.i((.., .., 0, .., ..))?;
1845        let w2 = ws.i((.., .., 1, .., ..))?;
1846
1847        let cfg = Conv2dConfig {
1848            padding: cfg.padding,
1849            stride: cfg.stride,
1850            dilation: cfg.dilation,
1851            groups: cfg.groups,
1852        };
1853
1854        Ok(Self {
1855            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
1856            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
1857        })
1858    }
1859}
1860
1861impl Module for Conv3dNoBias {
1862    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1863        let xs1 = xs.i((.., .., 0, .., ..))?;
1864        let xs2 = xs.i((.., .., 1, .., ..))?;
1865
1866        (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
1867    }
1868}
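
// Shape note for `Conv3dNoBias` (a sketch of the contract implied above): the input is expected as
// (N, C_in, T, H, W) and only the temporal slices at indices 0 and 1 are consumed, matching the
// temporal kernel size of 2 assumed by the weight split in `new`. Each slice goes through its own
// 2D convolution, the results are summed, and `unsqueeze(2)` restores a temporal axis, giving an
// output of shape (N, C_out, 1, H_out, W_out).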
1869
1870pub trait TensorInfExtend {
1871    fn is_inf(&self) -> Result<Self>
1872    where
1873        Self: Sized;
1874    fn any(&self) -> Result<bool>;
1875}
1876
1877impl TensorInfExtend for Tensor {
1878    fn is_inf(&self) -> Result<Self> {
1879        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
1880    }
1881
1882    fn any(&self) -> Result<bool> {
1883        let sum = self.sum_all()?;
1884        match self.dtype() {
1885            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
1886            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
1887            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
1888            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
1889            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
1890            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
1891            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
1892            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
1893            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
1894            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
1895        }
1896    }
1897}
1898
1899pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
1900    let mut max = match xs.dtype() {
1901        DType::U8 => u8::MAX as f32 - 1000.,
1902        DType::U32 => u32::MAX as f32 - 1000.,
1903        DType::I16 => i16::MAX as f32 - 1000.,
1904        DType::I32 => i32::MAX as f32 - 1000.,
1905        DType::I64 => i64::MAX as f32 - 1000.,
1906        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
1907        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
1908        DType::F32 => f32::MAX - 1000.,
1909        DType::F64 => f64::MAX as f32 - 1000.,
1910        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
1911    };
1912    if xs.is_inf()?.any()? {
1913        max -= 1000.;
1914    }
1915    xs.clamp(-max, max)
1916}
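
// A minimal sketch of the intended use (an assumed pattern, not called from here): clamp an
// activation tensor before a lossy downcast so values near the type's maximum do not overflow.
#[allow(dead_code)]
fn clamp_then_downcast_sketch(xs: &Tensor) -> Result<Tensor> {
    clamp_for_f16(xs)?.to_dtype(DType::F16)
}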
1917
1918pub struct FloatInfo {
1919    /// Minimum representable value.
1920    pub min: f64,
1921    /// Maximum representable value.
1922    pub max: f64,
1923    /// The difference between 1.0 and the next smallest representable float larger than 1.0.
1924    pub eps: f64,
1925    pub dtype: DType,
1926}
1927
1928pub trait GetFloatInfo {
1929    fn finfo(&self) -> Result<FloatInfo>;
1930}
1931
1932impl GetFloatInfo for DType {
1933    fn finfo(&self) -> Result<FloatInfo> {
1934        let finfo = match self {
1935            Self::BF16 => FloatInfo {
1936                min: bf16::MIN.to_f64(),
1937                max: bf16::MAX.to_f64(),
1938                eps: bf16::EPSILON.to_f64(),
1939                dtype: DType::BF16,
1940            },
1941            Self::F16 => FloatInfo {
1942                min: f16::MIN.to_f64(),
1943                max: f16::MAX.to_f64(),
1944                eps: f16::EPSILON.to_f64(),
1945                dtype: DType::F16,
1946            },
1947            Self::F32 => FloatInfo {
1948                min: f32::MIN as f64,
1949                max: f32::MAX as f64,
1950                eps: f32::EPSILON as f64,
1951                dtype: DType::F32,
1952            },
1953            Self::F64 => FloatInfo {
1954                min: f64::MIN,
1955                max: f64::MAX,
1956                eps: f64::EPSILON,
1957                dtype: DType::F64,
1958            },
1959            Self::F8E4M3 => FloatInfo {
1960                min: F8E4M3::MIN.to_f64(),
1961                max: F8E4M3::MAX.to_f64(),
1962                eps: F8E4M3::EPSILON.to_f64(),
1963                dtype: DType::F8E4M3,
1964            },
1965            other => {
1966                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
1967            }
1968        };
1969        Ok(finfo)
1970    }
1971}
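
// Usage sketch for `GetFloatInfo` (hypothetical helper): masked attention logits are commonly
// filled with the most negative representable value of the working dtype rather than a
// hard-coded constant, which is what `finfo().min` provides.
#[allow(dead_code)]
fn dtype_min_sketch(dtype: DType) -> Result<f64> {
    Ok(dtype.finfo()?.min)
}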
1972
1973#[derive(Clone)]
1974pub struct Mlp {
1975    pub gate: Arc<dyn QuantMethod>,
1976    pub up: Arc<dyn QuantMethod>,
1977    pub down: Arc<dyn QuantMethod>,
1978    act: Activation,
1979    params: Vec<usize>,
1980}
1981
1982impl Mlp {
1983    pub fn new(
1984        vb: ShardedVarBuilder,
1985        hidden_size: usize,
1986        intermediate_size: usize,
1987        quantization_config: &Option<QuantizedConfig>,
1988        hidden_act: Activation,
1989        comm: &Arc<mistralrs_quant::Comm>,
1990    ) -> Result<Self> {
1991        Ok(Self {
1992            gate: ColumnParallelLayer::new(
1993                hidden_size,
1994                intermediate_size,
1995                quantization_config,
1996                false,
1997                comm,
1998                vb.pp("gate_proj"),
1999            )?,
2000            up: ColumnParallelLayer::new(
2001                hidden_size,
2002                intermediate_size,
2003                quantization_config,
2004                false,
2005                comm,
2006                vb.pp("up_proj"),
2007            )?,
2008            down: RowParallelLayer::new(
2009                intermediate_size,
2010                hidden_size,
2011                quantization_config,
2012                false,
2013                comm,
2014                vb.pp("down_proj"),
2015            )?,
2016            act: hidden_act,
2017            params: vec![hidden_size, intermediate_size],
2018        })
2019    }
2020
2021    pub fn replicate(
2022        params: &[usize],
2023        vb: ShardedVarBuilder,
2024        act: Activation,
2025        comm: &Arc<mistralrs_quant::Comm>,
2026    ) -> Result<Self> {
2027        Self::new(vb, params[0], params[1], &None, act, comm)
2028    }
2029
2030    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2031        let original_dtype = xs.dtype();
2032        let mut xs = xs.clone();
2033        if let Some(t) = self.gate.quantized_act_type() {
2034            xs = xs.to_dtype(t)?;
2035        }
2036        let lhs = self.gate.forward(&xs)?;
2037        let rhs = self.up.forward(&xs)?;
2038        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
2039            &lhs,
2040            &rhs,
2041            self.act.try_into()?,
2042        )?)?;
2043        if self.gate.quantized_act_type().is_some() {
2044            res = res.to_dtype(original_dtype)?;
2045        }
2046        Ok(res)
2047    }
2048}
2049
2050impl AnyMoeTrainableLayer for Mlp {}
2051
2052impl MlpLayer for Mlp {
2053    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2054        let original_dtype = xs.dtype();
2055        let mut xs = xs.clone();
2056        if let Some(t) = self.gate.quantized_act_type() {
2057            xs = xs.to_dtype(t)?;
2058        }
2059        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
2060        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
2061        let mut res = if matches!(
2062            self.act,
2063            Activation::Gelu | Activation::Silu | Activation::Relu
2064        ) {
2065            MatMul.qmethod_matmul(
2066                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
2067                &*self.down,
2068            )?
2069        } else {
2070            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2071        };
2072        if self.gate.quantized_act_type().is_some() {
2073            res = res.to_dtype(original_dtype)?;
2074        }
2075        Ok(res)
2076    }
2077    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2078        vec![&mut self.gate, &mut self.up, &mut self.down]
2079    }
2080    fn clone(&self) -> Box<dyn MlpLayer> {
2081        Box::new(Clone::clone(self))
2082    }
2083    fn get_params(&self) -> &[usize] {
2084        &self.params
2085    }
2086    fn hidden_act(&self) -> Activation {
2087        self.act
2088    }
2089    // gate, up, down
2090    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2091        let gate = if let Some(ref delta) = deltas[0] {
2092            self.gate.add_delta_w(delta)?
2093        } else {
2094            self.gate.clone()
2095        };
2096        let up = if let Some(ref delta) = deltas[1] {
2097            self.up.add_delta_w(delta)?
2098        } else {
2099            self.up.clone()
2100        };
2101        let down = if let Some(ref delta) = deltas[2] {
2102            self.down.add_delta_w(delta)?
2103        } else {
2104            self.down.clone()
2105        };
2106
2107        Ok(Box::new(Self {
2108            gate,
2109            up,
2110            down,
2111            act: self.act,
2112            params: self.params.clone(),
2113        }))
2114    }
2115
2116    fn dtype_device(&self) -> (DType, Device) {
2117        self.gate.dtype_and_device()
2118    }
2119}
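
// Both `forward` paths above compute the same gated MLP,
//   y = down_proj(act(gate_proj(x)) * up_proj(x)),
// i.e. a SwiGLU-style block when `act` is SiLU. The trait `forward` only takes the fused
// `mul_and_act` route for Gelu, Silu, and Relu and otherwise applies the activation and the
// elementwise product separately, while the inherent `forward` always uses the fused call.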
2120
2121pub struct AvgPool2d {
2122    kernel_size: usize,
2123    stride: usize,
2124}
2125
2126impl AvgPool2d {
2127    pub fn new(kernel_size: usize, stride: usize) -> Self {
2128        Self {
2129            kernel_size,
2130            stride,
2131        }
2132    }
2133
2134    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2135        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2136    }
2137}
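
// Output-size note for `AvgPool2d` (standard pooling arithmetic; no padding is applied here):
// H_out = (H_in - kernel_size) / stride + 1, and likewise for the width. A 224x224 feature map
// with kernel_size = 2 and stride = 2, for example, pools down to 112x112.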
2138
2139/// Applies 2D reflection padding to a tensor of shape (N, C, H, W).
2140///
2141/// The `padding` argument is a 4-tuple (pad_left, pad_right, pad_top, pad_bottom).
2142/// For left padding, it reflects the values from column 1 up to pad_left (in reverse order);
2143/// for right padding, it reflects from the second-to-last column backwards, and similarly for
2144/// vertical (height) padding.
2145pub struct ReflectionPad2d {
2146    padding: (usize, usize, usize, usize),
2147}
2148
2149impl ReflectionPad2d {
2150    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2151        Self { padding }
2152    }
2153}
2154
2155impl Module for ReflectionPad2d {
2156    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2157        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2158
2159        let (_n, _c, h, w) = xs.dims4()?;
2160
2161        // --- Horizontal Padding (along width, axis = 3) ---
2162        // For left padding, we reflect columns 1..=pad_left (in reverse order).
2163        let left_pad = if pad_left > 0 {
2164            // Create indices: [pad_left, pad_left-1, ..., 1]
2165            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
2166            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2167        } else {
2168            None
2169        };
2170
2171        // For right padding, we reflect from the right side (excluding the last column).
2172        let right_pad = if pad_right > 0 {
2173            // For pad_right == 2, generate indices: [w-2, w-3, ... , w-1-pad_right]
2174            let start = w as i64 - 2;
2175            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2176            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
2177        } else {
2178            None
2179        };
2180
2181        // Concatenate horizontally (along width, dim=3)
2182        let x_padded_width = match (left_pad, right_pad) {
2183            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2184            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2185            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2186            (None, None) => xs.clone(),
2187        };
2188
2189        // --- Vertical Padding (along height, axis = 2) ---
2190        // For top padding, reflect rows 1..=pad_top (in reverse order)
2191        let top_pad = if pad_top > 0 {
2192            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2193            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2194        } else {
2195            None
2196        };
2197
2198        // For bottom padding, reflect from the bottom (excluding the last row)
2199        let bottom_pad = if pad_bottom > 0 {
2200            let start = h as i64 - 2;
2201            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2202            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
2203        } else {
2204            None
2205        };
2206
2207        // Concatenate vertically (along height, dim=2)
2208        let x_padded = match (top_pad, bottom_pad) {
2209            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2210            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2211            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2212            (None, None) => x_padded_width,
2213        };
2214
2215        Ok(x_padded)
2216    }
2217}
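
// Worked example of the reflection rule above: for a single row [a, b, c] with padding
// (1, 1, 0, 0), the left pad reflects column 1 and the right pad reflects column w - 2,
// producing [b, a, b, c, b]; the same rule is applied along the height for top/bottom padding.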
2218
2219pub struct ScaledEmbedding {
2220    scale: f64,
2221    embedding: Embedding,
2222}
2223
2224impl ScaledEmbedding {
2225    pub fn new(scale: f64, embedding: Embedding) -> Self {
2226        Self { scale, embedding }
2227    }
2228
2229    pub fn embeddings(&self) -> &Tensor {
2230        self.embedding.embeddings()
2231    }
2232}
2233
2234impl Module for ScaledEmbedding {
2235    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2236        xs.apply(&self.embedding)? * self.scale
2237    }
2238}
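
// Usage sketch for `ScaledEmbedding` (an assumed pattern: Gemma-style models scale token
// embeddings by sqrt(hidden_size) after the lookup).
#[allow(dead_code)]
fn scaled_embedding_sketch(embedding: Embedding, hidden_size: usize, ids: &Tensor) -> Result<Tensor> {
    ScaledEmbedding::new((hidden_size as f64).sqrt(), embedding).forward(ids)
}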