#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};

use candle_core::{
    quantized::{QMatMul, QTensor},
    Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{
    Conv2d, Conv2dConfig, Embedding, GroupNorm, LayerNorm, LayerNormConfig, Linear, Module,
};
use float8::F8E4M3;
use half::{bf16, f16};
use mistralrs_quant::{
    AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};

pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::repeat_kv;
use crate::{
    amoe::{AnyMoeTrainableLayer, MlpLayer},
    gguf::Content,
    models::llama,
    ops::SplitOp,
    vision_models::{
        gemma3::config::Gemma3TextConfig,
        llama4,
        mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
        phi4::Phi4MMConfig,
    },
};

pub use mistralrs_quant::MatMul;

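/// Load an embedding table from `vb`. AFQ-quantized checkpoints store the table as a
/// quantized linear layer, so it is dequantized back to a dense `(in_size, out_size)` tensor.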
pub fn embedding(
    in_size: usize,
    out_size: usize,
    vb: ShardedVarBuilder,
    config: &Option<QuantizedConfig>,
) -> Result<Embedding> {
    let embeddings = if let Some(cfg @ QuantizedConfig::Afq { .. }) = config {
        let afq_layer = AfqLayer::afq_linear_b(out_size, in_size, cfg, false, vb)?;
        afq_layer.dequantize_w()?
    } else {
        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
    };
    Ok(Embedding::new(embeddings, out_size))
}

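/// Load a [`LayerNorm`], reading a bias only when `config.affine` is set.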
pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: ShardedVarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get(size, "weight")?;
    if config.affine {
        let bias = vb.get(size, "bias")?;
        Ok(LayerNorm::new(weight, bias, config.eps))
    } else {
        Ok(LayerNorm::new_no_bias(weight, config.eps))
    }
}

pub fn group_norm(
    num_groups: usize,
    num_channels: usize,
    eps: f64,
    vb: ShardedVarBuilder,
) -> Result<GroupNorm> {
    let weight = vb.get(num_channels, "weight")?;
    let bias = vb.get(num_channels, "bias")?;
    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
}

pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    let bs = vb.get(out_channels, "bias")?;
    Ok(Conv2d::new(ws, Some(bs), cfg))
}

pub fn conv2d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv2dConfig,
    vb: ShardedVarBuilder,
) -> Result<Conv2d> {
    let ws = vb.get(
        (
            out_channels,
            in_channels / cfg.groups,
            kernel_size,
            kernel_size,
        ),
        "weight",
    )?;
    Ok(Conv2d::new(ws, None, cfg))
}

pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    let bs = vb.get(out_dim, "bias")?;
    Ok(Linear::new(ws, Some(bs)))
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
    let ws = vb.get((out_dim, in_dim), "weight")?;
    Ok(Linear::new(ws, None))
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: ShardedVarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}

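/// Root-mean-square normalization: `y = w * x / sqrt(mean(x^2) + eps)`.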
#[derive(Debug, Clone)]
pub struct RmsNorm {
    eps: f64,
    weight: Tensor,
}

impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        Ok(Self { eps, weight: w })
    }

    /// Gemma checkpoints store the norm weight as `w - 1`, so add 1 at load time.
    pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let w = vb.get(size, "weight")?;
        let w = (w + 1.0)?;
        Ok(Self { eps, weight: w })
    }

    /// Undo the `+ 1` applied by [`Self::new_gemma`], recovering the stored weight.
    pub fn undo_gemma(&self) -> Result<Self> {
        Ok(Self {
            eps: self.eps,
            weight: (&self.weight - 1.0)?,
        })
    }

    pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
        Ok(Self { eps, weight: w })
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

#[derive(Debug, Clone)]
pub struct F32RmsNorm {
    w: Tensor,
    eps: f64,
}

impl F32RmsNorm {
    pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            w: vb.get((size,), "weight")?,
            eps,
        })
    }

    pub fn weight(&self) -> &Tensor {
        &self.w
    }
}

impl Module for F32RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let initial_type = xs.dtype();
        let mut xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
    }
}

#[derive(Debug, Clone)]
pub struct QRmsNorm {
    eps: f64,
    weight: Tensor,
}

impl QRmsNorm {
    pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
        let scale = scale.dequantize(&scale.device())?;
        Ok(Self {
            eps: eps as f64,
            weight: scale,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
    }
}

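/// Phi-3 style rotary embedding with separate short/long sin-cos tables; the long tables
/// take over once a sequence exceeds `original_max_position_embeddings`.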
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
    #[serde(alias = "su")]
    #[serde(alias = "longrope")]
    Su,
    #[serde(alias = "yarn")]
    Yarn,
    #[serde(alias = "dynamic")]
    Dynamic,
    #[serde(alias = "linear")]
    Linear,
}

impl FromStr for ScaledRopeType {
    type Err = candle_core::Error;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "su" | "longrope" => Ok(Self::Su),
            "yarn" => Ok(Self::Yarn),
            "linear" => Ok(Self::Linear),
            "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Expected `su`/`longrope`, `yarn`, `linear`, or `dynamic` scaled RoPE type."
                    .to_string(),
            )),
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
    Classic {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    Scaled {
        short_factor: Vec<f64>,
        long_factor: Vec<f64>,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
    },
}

pub struct PhiRopeConfig {
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub original_max_position_embeddings: usize,
    pub rope_theta: f64,
    pub head_dim: usize,
    pub partial_rotary_factor: Option<f64>,
}

impl PhiRotaryEmbedding {
    fn new_classic_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            match scaling_type {
                ScaledRopeType::Su => {
                    (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
                }
                ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
                _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
            }
        };

        let inv_freq_long = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_short = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
            })
            .collect::<Vec<_>>();
        let inv_freq_len = inv_freq_long.len();

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;

        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
        let freqs_long = t.matmul(&inv_freq_long)?;
        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?;
        let freqs_short = t.matmul(&inv_freq_short)?;
        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

        Ok(Self {
            short_cos,
            short_sin,
            long_cos: Some(long_cos),
            long_sin: Some(long_sin),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_scaled(
        short_factor: &[f64],
        long_factor: &[f64],
        scaling_type: &ScaledRopeType,
        long_mscale: f64,
        short_mscale: f64,
        cfg: &PhiRopeConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

        if !matches!(scaling_type, ScaledRopeType::Su) {
            candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
        }

        if short_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
                short_factor.len(),
                dim / 2
            );
        }
        if long_factor.len() != dim / 2 {
            candle_core::bail!(
                "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
                long_factor.len(),
                dim / 2
            );
        }

        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;

        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
        let cfg: PhiRopeConfig = cfg.into();

        match &cfg.rope_scaling {
            Some(PhiRopeScalingConfig::Classic {
                short_factor,
                long_factor,
                scaling_type,
            }) => {
                Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
            }

            Some(PhiRopeScalingConfig::Scaled {
                short_factor,
                long_factor,
                scaling_type,
                long_mscale,
                short_mscale,
            }) => Self::new_scaled(
                short_factor,
                long_factor,
                scaling_type,
                *long_mscale,
                *short_mscale,
                &cfg,
                dtype,
                dev,
            ),

            None => Self::new_unscaled(&cfg, dtype, dev),
        }
    }

    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        let rot_dim = cos.dim(D::Minus1)? * 2;

        // Partial rotary: rotate only the first `rot_dim` features and pass the rest through.
        if rot_dim != q.dim(D::Minus1)? {
            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

            let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
                (q_embed, k_embed)
            } else {
                let mut q_embeds = Vec::new();
                let mut k_embeds = Vec::new();
                for (i, offset) in seqlen_offsets.iter().enumerate() {
                    let cos = cos.narrow(0, *offset, seq_len)?;
                    let sin = sin.narrow(0, *offset, seq_len)?;
                    let q_embed = candle_nn::rotary_emb::rope(
                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    let k_embed = candle_nn::rotary_emb::rope(
                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                        &cos,
                        &sin,
                    )?;
                    q_embeds.push(q_embed);
                    k_embeds.push(k_embed);
                }
                let q_rot = Tensor::cat(&q_embeds, 0)?;
                let k_rot = Tensor::cat(&k_embeds, 0)?;
                (q_rot, k_rot)
            };

            Ok((
                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
            ))
        } else if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed =
                    candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed =
                    candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

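/// RoPE with optional Llama 3 frequency rescaling, which smooths between scaled and
/// unscaled inverse frequencies based on each frequency's wavelength.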
#[derive(Debug, Clone)]
pub struct Llama3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
    #[serde(rename = "llama3")]
    Llama3,
    #[default]
    #[serde(rename = "default")]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
    pub factor: f32,
    pub low_freq_factor: f32,
    pub high_freq_factor: f32,
    pub original_max_position_embeddings: usize,
    pub rope_type: Llama3RopeType,
}

fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

fn calculate_default_inv_freq_llama4(cfg: &llama4::TextConfig) -> Vec<f32> {
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
        .collect()
}

impl Llama3RotaryEmbedding {
    pub fn new_llama3(
        dtype: DType,
        cfg: &llama::Config,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(rope_scaling) => {
                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.low_freq_factor;
                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.high_freq_factor;

                let inv_freq = calculate_default_inv_freq(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / rope_scaling.factor
                        } else {
                            let smooth = (rope_scaling.original_max_position_embeddings as f32
                                / wavelen
                                - rope_scaling.low_freq_factor)
                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_llama4(
        dtype: DType,
        cfg: &llama4::TextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(Llama3RopeConfig {
                rope_type: Llama3RopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(rope_scaling) => {
                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.low_freq_factor;
                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
                    / rope_scaling.high_freq_factor;

                let inv_freq = calculate_default_inv_freq_llama4(cfg)
                    .into_iter()
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / rope_scaling.factor
                        } else {
                            let smooth = (rope_scaling.original_max_position_embeddings as f32
                                / wavelen
                                - rope_scaling.low_freq_factor)
                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
        }
    }

    pub fn new_mllama3(
        dtype: DType,
        cfg: &MLlamaTextConfig,
        dev: &Device,
        is_gpt_neox: bool,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            None
            | Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Default,
                ..
            }) => Ok(Self(RotaryEmbedding::new(
                cfg.rope_theta,
                cfg.hidden_size / cfg.num_attention_heads,
                cfg.max_position_embeddings,
                dev,
                is_gpt_neox,
                dtype,
            )?)),
            Some(MLlamaRopeScaling {
                rope_type: MLlamaRopeType::Llama3,
                original_max_position_embeddings,
                factor,
                attention_factor: _,
                beta_fast: _,
                beta_slow: _,
                short_factor: _,
                long_factor: _,
                low_freq_factor,
                high_freq_factor,
            }) => {
                let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
                let low_freq_factor = low_freq_factor
                    .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
                let high_freq_factor = high_freq_factor
                    .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;

                let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
                let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;

                let inv_freq = (0..head_dim)
                    .step_by(2)
                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
                    .map(|freq| {
                        let wavelen = 2. * PI / freq;
                        if wavelen < high_freq_wavelen {
                            freq
                        } else if wavelen > low_freq_wavelen {
                            freq / factor
                        } else {
                            let smooth = (*original_max_position_embeddings as f32 / wavelen
                                - low_freq_factor)
                                / (high_freq_factor - low_freq_factor);
                            (1. - smooth) * freq / factor + smooth * freq
                        }
                    })
                    .collect::<Vec<_>>();
                let inv_freq_len = inv_freq.len();
                let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;

                let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
                    .to_dtype(DType::F32)?
                    .reshape((cfg.max_position_embeddings, 1))?;
                let freqs = t.matmul(&inv_freq)?;
                let sin = freqs.sin()?.to_dtype(dtype)?;
                let cos = freqs.cos()?.to_dtype(dtype)?;
                Ok(Self(RotaryEmbedding {
                    sin,
                    cos,
                    is_gpt_neox,
                }))
            }
            Some(MLlamaRopeScaling {
                rope_type: other, ..
            }) => {
                candle_core::bail!(
                    "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
                )
            }
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

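/// Multimodal rotary embedding (M-RoPE) for Qwen2-VL: position ids carry three rows
/// (temporal, height, width), and `mrope_section` describes how the head dim is split
/// among them.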
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        // Interleave the (temporal, height, width) sections: chunk `i` of the head dim
        // takes its positions from row `i % 3` of the expanded position ids.
        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct Qwen2_5VLRotaryEmbedding {
    inv_freq: Tensor,
    mrope_section: Vec<usize>,
}

impl Qwen2_5VLRotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        device: &Device,
        mrope_section: Vec<usize>,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
        Ok(Self {
            inv_freq,
            mrope_section,
        })
    }

    pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
        let inv_freq_expanded =
            self.inv_freq
                .reshape((1, 1, (), 1))?
                .repeat((3, position_ids.dim(1)?, 1, 1))?;
        let position_ids_expanded = position_ids.unsqueeze(2)?;
        let freqs = inv_freq_expanded
            .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
            .transpose(2, 3)?;
        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        let cos = Tensor::cat(
            &cos.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;
        let sin = Tensor::cat(
            &sin.split(&self.mrope_section, D::Minus1)?
                .into_iter()
                .enumerate()
                .map(|(i, m)| m.i(i % 3))
                .collect::<Result<Vec<_>>>()?,
            D::Minus1,
        )?
        .squeeze(0)?
        .to_dtype(dtype)?
        .contiguous()?;

        Ok((cos, sin))
    }

    pub fn forward(
        &self,
        (cos, sin): &(Tensor, Tensor),
        q: &mut Tensor,
        k: &mut Tensor,
    ) -> Result<()> {
        *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
        *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
        Ok(())
    }
}

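/// DeepSeek-V2 rotary embedding over the decoupled `qk_rope_head_dim`, with optional
/// YaRN long-context scaling.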
#[derive(Debug, Clone)]
pub struct DeepSeekV2RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum DeepSeekV2RopeScaling {
    Yarn {
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        mscale: f32,
        mscale_all_dim: f32,
        factor: f32,
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
    },
    LinearOrDynamic {
        #[serde(rename = "type")]
        scaling_type: ScaledRopeType,
        factor: f64,
    },
}

pub struct DeepSeekV2RopeConfig {
    pub rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub max_position_embeddings: usize,
    pub rope_theta: f32,
    pub qk_rope_head_dim: usize,
}

impl DeepSeekV2RotaryEmbedding {
    fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.qk_rope_head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    fn yarn_find_correction_dim(
        num_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> f32 {
        (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
            / (2. * base.ln())
    }

    fn yarn_find_correction_range(
        low_rot: f32,
        high_rot: f32,
        dim: usize,
        base: f32,
        max_position_embeddings: usize,
    ) -> (f32, f32) {
        let low =
            Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
        let high =
            Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
        (low.max(0.), high.min(dim as f32 - 1.))
    }

    fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
        if min == max {
            // Avoid a zero-width ramp (division by zero below).
            max += 0.001;
        }
        let linear_func =
            ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
        linear_func.clamp(0., 1.)
    }

    pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
        if scale <= 1. {
            return 1.;
        }
        0.1 * mscale * scale.ln() + 1.
    }

    #[allow(clippy::too_many_arguments)]
    fn new_yarn(
        cfg: &DeepSeekV2RopeConfig,
        dtype: DType,
        dev: &Device,
        original_max_position_embeddings: usize,
        beta_fast: f32,
        beta_slow: f32,
        factor: f32,
        mscale: f32,
        mscale_all_dim: f32,
    ) -> Result<Self> {
        let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
            .collect();
        let freq_extra_len = freq_extra.len();
        let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
        let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
            .step_by(2)
            .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
            .collect();
        let freq_inter_len = freq_inter.len();
        let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;

        let (low, high) = Self::yarn_find_correction_range(
            beta_fast,
            beta_slow,
            cfg.qk_rope_head_dim,
            cfg.rope_theta,
            original_max_position_embeddings,
        );
        let inv_freq_mask =
            (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
        let inv_freq = freq_inter
            .broadcast_mul(&(1. - &inv_freq_mask)?)?
            .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;

        let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((cfg.max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;

        let mscale =
            Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
        let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
        let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;

        Ok(Self { sin, cos })
    }

    pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(DeepSeekV2RopeScaling::LinearOrDynamic {
                scaling_type: _,
                factor: _,
            }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
            Some(DeepSeekV2RopeScaling::Yarn {
                original_max_position_embeddings,
                beta_fast,
                beta_slow,
                factor,
                mscale,
                mscale_all_dim,
                scaling_type: _,
            }) => Self::new_yarn(
                cfg,
                dtype,
                dev,
                *original_max_position_embeddings,
                *beta_fast,
                *beta_slow,
                *factor,
                *mscale,
                *mscale_all_dim,
            ),
            None => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;

        if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope_i(
                    &q.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope_i(
                    &k.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

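/// Phi-4 multimodal rotary embedding; like [`PhiRotaryEmbedding`] it keeps short and
/// long sin-cos tables and supports a partial rotary factor.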
#[derive(Debug, Clone)]
pub struct Phi4MMRotaryEmbedding {
    short_sin: Tensor,
    short_cos: Tensor,
    long_cos: Option<Tensor>,
    long_sin: Option<Tensor>,
    original_max_position_embeddings: usize,
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Phi4MMScaledRopeType {
    #[serde(alias = "longrope")]
    LongRope,
    #[default]
    Default,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Phi4MMRopeScalingConfig {
    short_factor: Option<Vec<f64>>,
    long_factor: Option<Vec<f64>>,
    #[serde(rename = "type")]
    scaling_type: Phi4MMScaledRopeType,
}

impl Phi4MMRotaryEmbedding {
    fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self {
            short_cos: cos,
            short_sin: sin,
            long_cos: None,
            long_sin: None,
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new_longrope(
        short_factor: &[f64],
        long_factor: &[f64],
        cfg: &Phi4MMConfig,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;

        let scale =
            cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
        let scaling_factor = if scale <= 1.0 {
            1.0
        } else {
            (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
        };

        let inv_freq_short: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_short = inv_freq_short.len();
        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
        let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_short = t_short.matmul(&inv_freq_short)?;
        let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;

        let inv_freq_long: Vec<_> = (0..dim)
            .step_by(2)
            .enumerate()
            .map(|(k, i)| {
                1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
            })
            .collect();
        let inv_freq_len_long = inv_freq_long.len();
        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
        let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs_long = t_long.matmul(&inv_freq_long)?;
        let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
        let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;

        Ok(Self {
            short_cos: cos_short,
            short_sin: sin_short,
            long_cos: Some(cos_long),
            long_sin: Some(sin_long),
            original_max_position_embeddings: cfg.original_max_position_embeddings,
        })
    }

    pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Phi4MMRopeScalingConfig {
                scaling_type: Phi4MMScaledRopeType::LongRope,
                short_factor: Some(short_factor),
                long_factor: Some(long_factor),
            }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),

            _ => Self::new_unscaled(cfg, dtype, dev),
        }
    }

    fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
        if self.long_cos.is_none() {
            return (&self.short_sin, &self.short_cos);
        }
        let seq_len = position_ids.iter().max().unwrap() + 1;
        if seq_len > self.original_max_position_embeddings {
            (
                self.long_sin.as_ref().unwrap(),
                self.long_cos.as_ref().unwrap(),
            )
        } else {
            (&self.short_sin, &self.short_cos)
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);

        let rot_dim = cos.dim(D::Minus1)? * 2;
        let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
        let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
        let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
        let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

        let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
            let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
            let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
            (q_embed, k_embed)
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = cos.narrow(0, *offset, seq_len)?;
                let sin = sin.narrow(0, *offset, seq_len)?;
                let q_embed = candle_nn::rotary_emb::rope(
                    &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                let k_embed = candle_nn::rotary_emb::rope(
                    &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
                    &cos,
                    &sin,
                )?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            let q_rot = Tensor::cat(&q_embeds, 0)?;
            let k_rot = Tensor::cat(&k_embeds, 0)?;
            (q_rot, k_rot)
        };

        Ok((
            Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
            Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
        ))
    }
}

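/// Gemma 3 rotary embedding; `linear` rope scaling simply divides the inverse
/// frequencies by `factor`.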
#[derive(Debug, Clone)]
pub struct Gemma3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Gemma3ScaledRopeType {
    #[serde(alias = "linear")]
    Linear,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Gemma3RopeScalingConfig {
    factor: f64,
    rope_type: Gemma3ScaledRopeType,
}

impl Gemma3RotaryEmbedding {
    fn new_linear(
        cfg: &Gemma3TextConfig,
        factor: f64,
        is_gpt_neox: bool,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let max_seq_len = cfg.max_position_embeddings;
        let dim = cfg.head_dim;

        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let inv_freq = (inv_freq / factor)?;

        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;
        Ok(Self(RotaryEmbedding {
            cos,
            sin,
            is_gpt_neox,
        }))
    }

    pub fn new(
        is_gpt_neox: bool,
        dtype: DType,
        cfg: &Gemma3TextConfig,
        dev: &Device,
    ) -> Result<Self> {
        match &cfg.rope_scaling {
            Some(Gemma3RopeScalingConfig {
                rope_type: Gemma3ScaledRopeType::Linear,
                factor,
            }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),

            _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
        }
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        self.0.forward(q, k, seqlen_offsets)
    }
}

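/// Rotary embedding used by Dia: timescales are spaced geometrically between
/// `min_timescale` and `max_timescale` over half the head dim.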
pub struct DiaRotaryEmbedding {
    timescale: Tensor,
    dtype: DType,
}

impl DiaRotaryEmbedding {
    pub fn new(
        min_timescale: f32,
        max_timescale: f32,
        head_dim: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        assert_eq!(head_dim % 2, 0);
        let half_embedding_dim = head_dim / 2;

        let fraction = (0..half_embedding_dim).map(|i| 2f32 * i as f32 / head_dim as f32);
        let timescale = fraction
            .map(|x| min_timescale * (max_timescale / min_timescale).powf(x))
            .collect::<Vec<_>>();

        let timescale_len = timescale.len();
        let timescale = Tensor::from_vec(timescale, timescale_len, device)?;

        Ok(Self { timescale, dtype })
    }

    pub fn forward(&self, xs: &Tensor, positions: &Tensor) -> Result<Tensor> {
        let freqs = positions
            .unsqueeze(D::Minus1)?
            .unsqueeze(D::Minus1)?
            .broadcast_div(&self.timescale)?;

        let sin = freqs.sin()?.to_dtype(self.dtype)?;
        let cos = freqs.cos()?.to_dtype(self.dtype)?;

        let split = xs.chunk(2, D::Minus1)?;
        let first_half = &split[0];
        let second_half = &split[1];

        let first_part = (first_half.broadcast_mul(&cos)? - second_half.broadcast_mul(&sin)?)?;
        let second_part = (second_half.broadcast_mul(&cos)? + first_half.broadcast_mul(&sin)?)?;

        Tensor::cat(&[first_part, second_part], D::Minus1)
    }
}
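
/// Linear layer backed by a [`QMatMul`], with an optional (dequantized) bias.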
#[derive(Debug, Clone)]
pub struct QLinear {
    inner: QMatMul,
    bias: Option<Tensor>,
    dtype: DType,
}

impl QLinear {
    pub fn new<R: std::io::Read + std::io::Seek>(
        ct: &mut Content<'_, R>,
        name: &str,
        device: &Device,
    ) -> Result<Self> {
        let w = ct.tensor(&format!("{name}.weight"), device)?;
        let b = ct.tensor(&format!("{name}.bias"), device)?;
        let inner = QMatMul::from_qtensor(w)?;
        let bias = b.dequantize(device)?;
        Ok(Self {
            inner,
            bias: Some(bias),
            dtype: DType::F32,
        })
    }

    pub fn from_linear(linear: Linear) -> Self {
        Self {
            inner: QMatMul::Tensor(linear.weight().clone()),
            bias: linear.bias().cloned(),
            dtype: linear.weight().dtype(),
        }
    }

    pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
        let dtype = w.dtype();
        Self {
            inner: QMatMul::Tensor(w),
            bias: b,
            dtype,
        }
    }

    pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
        if let Some(ref b) = b {
            assert_eq!(b.dtype(), DType::F32);
        }
        Self {
            inner: QMatMul::QTensor(Arc::new(w)),
            bias: b,
            dtype: DType::F32,
        }
    }

    pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
        Self {
            inner,
            bias: old.bias.clone(),
            dtype: old.dtype,
        }
    }

    pub fn inner(&mut self) -> &mut QMatMul {
        &mut self.inner
    }

    pub fn inner_ref(&self) -> &QMatMul {
        &self.inner
    }

    pub fn is_quant(&self) -> bool {
        matches!(self.inner, QMatMul::QTensor(_))
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
        self.bias.as_mut()
    }
}

impl Module for QLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = if self.is_quant() {
            xs.to_dtype(DType::F32)?
        } else {
            xs.clone()
        };
        if let Some(bias) = &self.bias {
            self.inner
                .forward(&xs)?
                .broadcast_add(bias)?
                .to_dtype(self.dtype)
        } else {
            self.inner.forward(&xs)?.to_dtype(self.dtype)
        }
    }
}

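/// Standard rotary position embedding with precomputed sin/cos tables.
/// `is_gpt_neox` selects the half-split layout over the interleaved one.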
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
    is_gpt_neox: bool,
}

impl RotaryEmbedding {
    pub fn new(
        base: f32,
        head_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn new_partial(
        base: f32,
        rot_dim: usize,
        max_position_embeddings: usize,
        device: &Device,
        is_gpt_neox: bool,
        dtype: DType,
    ) -> Result<Self> {
        let inv_freq: Vec<_> = (0..rot_dim)
            .step_by(2)
            .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_position_embeddings, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?.to_dtype(dtype)?;
        let cos = freqs.cos()?.to_dtype(dtype)?;

        Ok(Self {
            cos,
            sin,
            is_gpt_neox,
        })
    }

    pub fn forward(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offsets: &[usize],
    ) -> Result<(Tensor, Tensor)> {
        let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
        let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;

        let rope = if self.is_gpt_neox {
            candle_nn::rotary_emb::rope
        } else {
            candle_nn::rotary_emb::rope_i
        };

        // With CUDA enabled and matching q/k head counts, use the fused in-place rotary kernel.
        if cfg!(feature = "cuda") && qh == kh {
            let (cos, sin) = if seqlen_offsets.len() == 1 {
                (
                    self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
                    self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
                )
            } else {
                let mut cos_s = Vec::new();
                let mut sin_s = Vec::new();
                for offset in seqlen_offsets {
                    cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
                    sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
                }
                (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
            };

            let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
            let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
            mistralrs_quant::rotary::apply_rotary_inplace(
                &q_embed,
                &k_embed,
                &cos,
                &sin,
                self.is_gpt_neox,
            )?;
            let mut q = q_embed
                .reshape((b_sz, seq_len, qh, n_embd))?
                .transpose(1, 2)?;
            let mut k = k_embed
                .reshape((b_sz, seq_len, kh, n_embd))?
                .transpose(1, 2)?;
            if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
                q = q.contiguous()?;
                k = k.contiguous()?;
            }
            Ok((q, k))
        } else if seqlen_offsets.len() == 1 {
            let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
            let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
            let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
            let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
            Ok((q_embed, k_embed))
        } else {
            let mut q_embeds = Vec::new();
            let mut k_embeds = Vec::new();
            for (i, offset) in seqlen_offsets.iter().enumerate() {
                let cos = self.cos.narrow(0, *offset, seq_len)?;
                let sin = self.sin.narrow(0, *offset, seq_len)?;
                let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
                q_embeds.push(q_embed);
                k_embeds.push(k_embed);
            }
            Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
        }
    }
}

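/// Activation functions used across model configs; the serde aliases match the names
/// found in Hugging Face `config.json` files.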
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
    Relu,
    Relu2,
    Relu6,
    Silu,
    Sigmoid,
    HardSigmoid,
    Swiglu,
    Swish,
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu_erf(),
            // candle's `gelu` is the tanh-approximated variant.
            Self::NewGelu => xs.gelu(),
            Self::Relu => xs.relu(),
            Self::Relu2 => xs.relu()?.sqr(),
            Self::Relu6 => xs.clamp(0f32, 6f32),
            Self::Silu => xs.silu(),
            Self::Sigmoid => candle_nn::ops::sigmoid(xs),
            Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
            Self::Swiglu => candle_nn::ops::swiglu(xs),
            Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
            Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
            Self::GeluPytorchTanh => xs.gelu(),
            Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
        }
    }
}

impl TryInto<candle_nn::Activation> for Activation {
    type Error = candle_core::Error;

    fn try_into(self) -> Result<candle_nn::Activation> {
        match self {
            Self::Gelu => Ok(candle_nn::Activation::Gelu),
            Self::Relu => Ok(candle_nn::Activation::Relu),
            Self::Silu => Ok(candle_nn::Activation::Silu),
            Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
            Self::Relu2 => Ok(candle_nn::Activation::Relu2),
            Self::Relu6 => Ok(candle_nn::Activation::Relu6),
            Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
            Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
            Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
            Self::Swish => Ok(candle_nn::Activation::Swish),
            Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
            Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
            Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
            Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
            Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
    pub padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
}

impl Default for Conv3dConfig {
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}

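/// A 3D convolution emulated with two 2D convolutions summed over the temporal axis;
/// this assumes the temporal kernel size (`kernel_sizes[0]`) and temporal input extent
/// are both 2, matching the indexing in `new` and `forward`.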
pub struct Conv3dNoBias {
    conv2d_1: Conv2d,
    conv2d_2: Conv2d,
}

impl Conv3dNoBias {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_sizes: [usize; 3],
        cfg: Conv3dConfig,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let ws = vb.get(
            (
                out_channels,
                in_channels / cfg.groups,
                kernel_sizes[0],
                kernel_sizes[1],
                kernel_sizes[2],
            ),
            "weight",
        )?;

        // Split the kernel along the temporal dimension into two 2D kernels.
        let w1 = ws.i((.., .., 0, .., ..))?;
        let w2 = ws.i((.., .., 1, .., ..))?;

        let cfg = Conv2dConfig {
            padding: cfg.padding,
            stride: cfg.stride,
            dilation: cfg.dilation,
            groups: cfg.groups,
        };

        Ok(Self {
            conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
            conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
        })
    }
}

impl Module for Conv3dNoBias {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs1 = xs.i((.., .., 0, .., ..))?;
        let xs2 = xs.i((.., .., 1, .., ..))?;

        (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
    }
}

pub trait TensorInfExtend {
    fn is_inf(&self) -> Result<Self>
    where
        Self: Sized;
    fn any(&self) -> Result<bool>;
}

impl TensorInfExtend for Tensor {
    fn is_inf(&self) -> Result<Self> {
        self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
    }

    fn any(&self) -> Result<bool> {
        // A nonzero sum means at least one element is set.
        let sum = self.sum_all()?;
        match self.dtype() {
            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
        }
    }
}

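/// Clamp `xs` to slightly inside its dtype's representable range, clamping harder when
/// infinities are present, so downstream f16 conversion does not overflow.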
pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
    let mut max = match xs.dtype() {
        DType::U8 => u8::MAX as f32 - 1000.,
        DType::U32 => u32::MAX as f32 - 1000.,
        DType::I16 => i16::MAX as f32 - 1000.,
        DType::I32 => i32::MAX as f32 - 1000.,
        DType::I64 => i64::MAX as f32 - 1000.,
        DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
        DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
        DType::F32 => f32::MAX - 1000.,
        DType::F64 => f64::MAX as f32 - 1000.,
        DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
    };
    if xs.is_inf()?.any()? {
        max -= 1000.;
    }
    xs.clamp(-max, max)
}

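/// Float range information for a [`DType`], in the spirit of numpy's `finfo`.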
1965pub struct FloatInfo {
1966 pub min: f64,
1968 pub max: f64,
1970 pub eps: f64,
1972 pub dtype: DType,
1973}
1974
1975pub trait GetFloatInfo {
1976 fn finfo(&self) -> Result<FloatInfo>;
1977}
1978
impl GetFloatInfo for DType {
    fn finfo(&self) -> Result<FloatInfo> {
        let finfo = match self {
            Self::BF16 => FloatInfo {
                min: bf16::MIN.to_f64(),
                max: bf16::MAX.to_f64(),
                eps: bf16::EPSILON.to_f64(),
                dtype: DType::BF16,
            },
            Self::F16 => FloatInfo {
                min: f16::MIN.to_f64(),
                max: f16::MAX.to_f64(),
                eps: f16::EPSILON.to_f64(),
                dtype: DType::F16,
            },
            Self::F32 => FloatInfo {
                min: f32::MIN as f64,
                max: f32::MAX as f64,
                eps: f32::EPSILON as f64,
                dtype: DType::F32,
            },
            Self::F64 => FloatInfo {
                min: f64::MIN,
                max: f64::MAX,
                eps: f64::EPSILON,
                dtype: DType::F64,
            },
            Self::F8E4M3 => FloatInfo {
                min: F8E4M3::MIN.to_f64(),
                max: F8E4M3::MAX.to_f64(),
                eps: F8E4M3::EPSILON.to_f64(),
                dtype: DType::F8E4M3,
            },
            other => {
                candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
            }
        };
        Ok(finfo)
    }
}

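// Small sketch of `finfo` usage: dtype-dependent constants such as the most
// negative representable value (a common mask fill before a softmax). Integer
// dtypes are rejected.
#[cfg(test)]
mod finfo_tests {
    use super::*;

    #[test]
    fn reports_float_ranges() -> Result<()> {
        let info = DType::F16.finfo()?;
        assert_eq!(info.max, f16::MAX.to_f64());
        assert!(DType::U8.finfo().is_err());
        Ok(())
    }
}
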
/// A gated MLP block, computing `down(act(gate(x)) * up(x))`, with
/// tensor-parallel gate/up (column) and down (row) projections.
#[derive(Clone)]
pub struct Mlp {
    pub gate: Arc<dyn QuantMethod>,
    pub up: Arc<dyn QuantMethod>,
    pub down: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl Mlp {
    pub fn new(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
            params: vec![hidden_size, intermediate_size],
        })
    }

    pub fn replicate(
        params: &[usize],
        vb: ShardedVarBuilder,
        act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Self::new(vb, params[0], params[1], &None, act, comm)
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        // Quantized layers may require a specific activation dtype; cast in
        // before the matmuls and back out afterwards.
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        // Fused act(gate) * up, followed by the down projection.
        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
            &lhs,
            &rhs,
            self.act.try_into()?,
        )?)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

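// Construction sketch (hypothetical variable names): `Mlp::new` expects
// `gate_proj`/`up_proj`/`down_proj` weights under the given prefix, e.g.
//
//     let mlp = Mlp::new(vb.pp("mlp"), hidden, intermediate, &quant_cfg, Activation::Silu, &comm)?;
//     let ys = mlp.forward(&xs)?; // (batch, seq, hidden) -> (batch, seq, hidden)
//
// where `vb` is a `ShardedVarBuilder` over the checkpoint and `comm` the
// tensor-parallel communicator for this rank.
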
impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
        // Use the fused multiply-and-activate kernel only for the activations
        // it supports; otherwise apply the activation separately.
        let mut res = if matches!(
            self.act,
            Activation::Gelu | Activation::Silu | Activation::Relu
        ) {
            MatMul.qmethod_matmul(
                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
                &*self.down,
            )?
        } else {
            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
        };
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate, &mut self.up, &mut self.down]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        // Deltas are applied in the order [gate, up, down]; a `None` leaves
        // the corresponding projection unchanged.
        let gate = if let Some(ref delta) = deltas[0] {
            self.gate.add_delta_w(delta)?
        } else {
            self.gate.clone()
        };
        let up = if let Some(ref delta) = deltas[1] {
            self.up.add_delta_w(delta)?
        } else {
            self.up.clone()
        };
        let down = if let Some(ref delta) = deltas[2] {
            self.down.add_delta_w(delta)?
        } else {
            self.down.clone()
        };

        Ok(Box::new(Self {
            gate,
            up,
            down,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate.dtype_and_device()
    }
}

/// 2D average pooling with a square kernel and a uniform stride.
pub struct AvgPool2d {
    kernel_size: usize,
    stride: usize,
}

impl AvgPool2d {
    pub fn new(kernel_size: usize, stride: usize) -> Self {
        Self {
            kernel_size,
            stride,
        }
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
    }
}

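// Minimal sketch: a 2x2 average pool with stride 2 halves each spatial
// dimension of an NCHW tensor and averages each block.
#[cfg(test)]
mod avg_pool2d_tests {
    use super::*;

    #[test]
    fn halves_spatial_dims() -> Result<()> {
        let xs = Tensor::arange(0f32, 16., &Device::Cpu)?.reshape((1, 1, 4, 4))?;
        let pooled = AvgPool2d::new(2, 2).forward(&xs)?;
        assert_eq!(pooled.dims4()?, (1, 1, 2, 2));
        assert_eq!(
            pooled.flatten_all()?.to_vec1::<f32>()?,
            vec![2.5, 4.5, 10.5, 12.5]
        );
        Ok(())
    }
}
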
/// Reflection padding for NCHW tensors, following the semantics of PyTorch's
/// `nn.ReflectionPad2d`.
///
/// The padding is `(left, right, top, bottom)`; values are mirrored around
/// the border without repeating the edge element, so each padding amount must
/// be smaller than the corresponding input dimension.
pub struct ReflectionPad2d {
    padding: (usize, usize, usize, usize),
}

impl ReflectionPad2d {
    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
        Self { padding }
    }
}

impl Module for ReflectionPad2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;

        let (_n, _c, h, w) = xs.dims4()?;

        // Columns mirrored to the left of the input, excluding the edge
        // column (e.g. pad_left = 2 selects columns [2, 1]). Index tensors
        // are created on the input's device so this also works on GPU.
        let left_pad = if pad_left > 0 {
            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
        } else {
            None
        };

        // Columns mirrored to the right, starting from the second-to-last.
        let right_pad = if pad_right > 0 {
            let start = w as i64 - 2;
            let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
            Some(xs.index_select(&Tensor::new(indices, xs.device())?, 3)?)
        } else {
            None
        };

        // Concatenate along the width dimension.
        let x_padded_width = match (left_pad, right_pad) {
            (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
            (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
            (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
            (None, None) => xs.clone(),
        };

        // Rows mirrored above the input, excluding the edge row.
        let top_pad = if pad_top > 0 {
            let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
        } else {
            None
        };

        // Rows mirrored below the input.
        let bottom_pad = if pad_bottom > 0 {
            let start = h as i64 - 2;
            let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
            Some(x_padded_width.index_select(&Tensor::new(indices, xs.device())?, 2)?)
        } else {
            None
        };

        // Concatenate along the height dimension.
        let x_padded = match (top_pad, bottom_pad) {
            (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
            (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
            (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
            (None, None) => x_padded_width,
        };

        Ok(x_padded)
    }
}

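// Sketch of the reflection semantics: interior elements are mirrored without
// repeating the edge, matching PyTorch's `nn.ReflectionPad2d`.
#[cfg(test)]
mod reflection_pad2d_tests {
    use super::*;

    #[test]
    fn mirrors_interior_elements() -> Result<()> {
        // A single row [0, 1, 2], padded by one column on each side.
        let xs = Tensor::new(&[0f32, 1., 2.], &Device::Cpu)?.reshape((1, 1, 1, 3))?;
        let padded = ReflectionPad2d::new((1, 1, 0, 0)).forward(&xs)?;
        assert_eq!(
            padded.flatten_all()?.to_vec1::<f32>()?,
            vec![1., 0., 1., 2., 1.]
        );
        Ok(())
    }
}
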
/// An embedding layer whose output is multiplied by a fixed scale, e.g. the
/// `sqrt(hidden_size)` scaling some architectures apply to token embeddings.
pub struct ScaledEmbedding {
    scale: f64,
    embedding: Embedding,
}

impl ScaledEmbedding {
    pub fn new(scale: f64, embedding: Embedding) -> Self {
        Self { scale, embedding }
    }

    pub fn embeddings(&self) -> &Tensor {
        self.embedding.embeddings()
    }
}

impl Module for ScaledEmbedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.embedding)? * self.scale
    }
}
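
// Minimal sketch: the lookup result is multiplied by the constant scale.
#[cfg(test)]
mod scaled_embedding_tests {
    use super::*;

    #[test]
    fn scales_lookup() -> Result<()> {
        let weights = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
        let emb = ScaledEmbedding::new(2.0, Embedding::new(weights, 2));
        let ids = Tensor::new(&[1u32], &Device::Cpu)?;
        assert_eq!(
            emb.forward(&ids)?.flatten_all()?.to_vec1::<f32>()?,
            vec![6., 8.]
        );
        Ok(())
    }
}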