1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{f32::consts::PI, ops::Mul, str::FromStr, sync::Arc};
4
5use candle_core::{
6 quantized::{QMatMul, QTensor},
7 Context, DType, Device, IndexOp, Result, Tensor, D,
8};
9use candle_nn::{
10 Conv2d, Conv2dConfig, Embedding, GroupNorm, LayerNorm, LayerNormConfig, Linear, Module,
11};
12use float8::F8E4M3;
13use half::{bf16, f16};
14use mistralrs_quant::{
15 AfqLayer, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
16 ShardedVarBuilder,
17};
18use serde::{Deserialize, Serialize};
19
20pub use crate::attention::Sdpa;
21pub use crate::layers_masker::CausalMasker;
22pub use crate::layers_utils::repeat_kv;
23use crate::{
24 amoe::{AnyMoeTrainableLayer, MlpLayer},
25 gguf::Content,
26 models::llama,
27 ops::SplitOp,
28 vision_models::{
29 gemma3::config::Gemma3TextConfig,
30 mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
31 phi4::Phi4MMConfig,
32 },
33};
34
35pub use mistralrs_quant::MatMul;
36
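/// Load an embedding table of shape `(in_size, out_size)`.
///
/// If the model ships AFQ-quantized weights, the quantized linear layer is loaded and
/// dequantized back to a dense tensor; otherwise the plain `weight` tensor is read as-is.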
37pub fn embedding(
38 in_size: usize,
39 out_size: usize,
40 vb: ShardedVarBuilder,
41 config: &Option<QuantizedConfig>,
42) -> Result<Embedding> {
    let embeddings = if let Some(afq_cfg @ QuantizedConfig::Afq { .. }) = config {
        // AFQ embeddings are stored as a quantized linear layer; dequantize to a dense weight.
        let afq_layer = AfqLayer::afq_linear_b(out_size, in_size, afq_cfg, false, vb)?;
        afq_layer.dequantize_w()?
    } else {
        vb.get_with_hints((in_size, out_size), "weight", Default::default())?
    };
51 Ok(Embedding::new(embeddings, out_size))
52}
53
54pub fn layer_norm<C: Into<LayerNormConfig>>(
55 size: usize,
56 config: C,
57 vb: ShardedVarBuilder,
58) -> Result<LayerNorm> {
59 let config = config.into();
60 let weight = vb.get(size, "weight")?;
61 if config.affine {
62 let bias = vb.get(size, "bias")?;
63 Ok(LayerNorm::new(weight, bias, config.eps))
64 } else {
65 Ok(LayerNorm::new_no_bias(weight, config.eps))
66 }
67}
68
69pub fn group_norm(
70 num_groups: usize,
71 num_channels: usize,
72 eps: f64,
73 vb: ShardedVarBuilder,
74) -> Result<GroupNorm> {
75 let weight = vb.get(num_channels, "weight")?;
76 let bias = vb.get(num_channels, "bias")?;
77 GroupNorm::new(weight, bias, num_channels, num_groups, eps)
78}
79
80pub fn conv2d(
81 in_channels: usize,
82 out_channels: usize,
83 kernel_size: usize,
84 cfg: Conv2dConfig,
85 vb: ShardedVarBuilder,
86) -> Result<Conv2d> {
87 let ws = vb.get(
88 (
89 out_channels,
90 in_channels / cfg.groups,
91 kernel_size,
92 kernel_size,
93 ),
94 "weight",
95 )?;
96 let bs = vb.get(out_channels, "bias")?;
97 Ok(Conv2d::new(ws, Some(bs), cfg))
98}
99
100pub fn conv2d_no_bias(
101 in_channels: usize,
102 out_channels: usize,
103 kernel_size: usize,
104 cfg: Conv2dConfig,
105 vb: ShardedVarBuilder,
106) -> Result<Conv2d> {
107 let ws = vb.get(
108 (
109 out_channels,
110 in_channels / cfg.groups,
111 kernel_size,
112 kernel_size,
113 ),
114 "weight",
115 )?;
116 Ok(Conv2d::new(ws, None, cfg))
117}
118
119pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
120 let ws = vb.get((out_dim, in_dim), "weight")?;
121 let bs = vb.get(out_dim, "bias")?;
122 Ok(Linear::new(ws, Some(bs)))
123}
124
125pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
126 let ws = vb.get((out_dim, in_dim), "weight")?;
127 Ok(Linear::new(ws, None))
128}
129
130pub fn linear_b(
131 in_dim: usize,
132 out_dim: usize,
133 bias: bool,
134 vb: ShardedVarBuilder,
135) -> Result<Linear> {
136 if bias {
137 linear(in_dim, out_dim, vb)
138 } else {
139 linear_no_bias(in_dim, out_dim, vb)
140 }
141}
142
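/// Root-mean-square layer normalization: `y = w ⊙ x / sqrt(mean(x²) + eps)`, delegated to
/// `candle_nn::ops::rms_norm`.
///
/// `new_gemma` adds 1 to the stored weight, matching Gemma's convention of keeping the scale as
/// an offset from one (effective scale `1 + w`); `undo_gemma` reverses that shift.
///
/// A rough usage sketch (names such as `vb`, `hidden_states`, and the 1e-6 epsilon are
/// illustrative only):
///
/// ```ignore
/// let norm = RmsNorm::new(hidden_size, 1e-6, vb.pp("input_layernorm"))?;
/// let normed = norm.forward(&hidden_states)?;
/// ```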
143#[derive(Debug, Clone)]
144pub struct RmsNorm {
145 eps: f64,
146 weight: Tensor,
147}
148
149impl RmsNorm {
150 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
151 let w = vb.get(size, "weight")?;
152 Ok(Self { eps, weight: w })
153 }
154
155 pub fn new_gemma(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
157 let w = vb.get(size, "weight")?;
158 let w = (w + 1.0)?;
159 Ok(Self { eps, weight: w })
160 }
161
162 pub fn undo_gemma(&self) -> Result<Self> {
164 Ok(Self {
165 eps: self.eps,
166 weight: (&self.weight - 1.0)?,
167 })
168 }
169
170 pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
171 Ok(Self { eps, weight: w })
172 }
173
174 pub fn weight(&self) -> &Tensor {
175 &self.weight
176 }
177}
178
179impl Module for RmsNorm {
180 fn forward(&self, x: &Tensor) -> Result<Tensor> {
181 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
182 }
183}
184
185#[derive(Debug, Clone)]
186pub struct F32RmsNorm {
187 w: Tensor,
188 eps: f64,
189}
190
191impl F32RmsNorm {
192 pub fn new(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
193 Ok(Self {
194 w: vb.get((size,), "weight")?,
195 eps,
196 })
197 }
198
199 pub fn weight(&self) -> &Tensor {
200 &self.w
201 }
202}
203
204impl Module for F32RmsNorm {
205 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
206 let initial_type = xs.dtype();
207 let mut xs = xs.to_dtype(DType::F32)?;
208 let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
209 xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
210 xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
211 }
212}
213
214#[derive(Debug, Clone)]
215pub struct QRmsNorm {
216 eps: f64,
217 weight: Tensor,
218}
219
220impl QRmsNorm {
221 pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
222 let scale = scale.dequantize(&scale.device())?;
223 Ok(Self {
224 eps: eps as f64,
225 weight: scale,
226 })
227 }
228
229 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
230 candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
231 }
232}
233
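/// Phi-3 style rotary embedding with optional LongRoPE (`su`/`longrope`) scaling.
///
/// Two cos/sin tables are precomputed: the `short_*` tables cover sequences up to
/// `original_max_position_embeddings`, and the optional `long_*` tables (rescaled per frequency
/// by the configured short/long factors) are used once the largest position id exceeds that
/// limit. When `partial_rotary_factor < 1`, only the first `rot_dim` features of `q`/`k` are
/// rotated and the remainder passes through unchanged.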
234#[derive(Debug, Clone)]
236pub struct PhiRotaryEmbedding {
237 short_sin: Tensor,
238 short_cos: Tensor,
239 long_cos: Option<Tensor>,
240 long_sin: Option<Tensor>,
241 original_max_position_embeddings: usize,
242}
243
244#[derive(Debug, Clone, Deserialize, Serialize)]
245#[serde(rename_all = "lowercase")]
246pub enum ScaledRopeType {
247 #[serde(alias = "su")]
248 #[serde(alias = "longrope")]
249 Su,
250 #[serde(alias = "yarn")]
251 Yarn,
252 #[serde(alias = "dynamic")]
253 Dynamic,
254 #[serde(alias = "linear")]
255 Linear,
256}
257
258impl FromStr for ScaledRopeType {
259 type Err = candle_core::Error;
260 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
261 match s {
262 "su" | "longrope" => Ok(Self::Su),
263 "yarn" => Ok(Self::Yarn),
264 "linear" => Ok(Self::Linear),
265 "dynamic" => Ok(Self::Dynamic),
            _ => Err(candle_core::Error::Msg(
                "Unknown scaled RoPE type: expected one of `su`/`longrope`, `yarn`, `linear`, or `dynamic`.".to_string(),
            )),
269 }
270 }
271}
272
273#[derive(Debug, Clone, Deserialize, Serialize)]
274#[serde(untagged)]
275pub enum PhiRopeScalingConfig {
276 Classic {
277 short_factor: Vec<f64>,
278 long_factor: Vec<f64>,
279 #[serde(rename = "type")]
280 scaling_type: ScaledRopeType,
281 },
282 Scaled {
283 short_factor: Vec<f64>,
284 long_factor: Vec<f64>,
285 #[serde(rename = "type")]
286 scaling_type: ScaledRopeType,
287 long_mscale: f64,
288 short_mscale: f64,
289 },
290}
291
292pub struct PhiRopeConfig {
293 pub rope_scaling: Option<PhiRopeScalingConfig>,
294 pub max_position_embeddings: usize,
295 pub original_max_position_embeddings: usize,
296 pub rope_theta: f64,
297 pub head_dim: usize,
298 pub partial_rotary_factor: Option<f64>,
299}
300
301impl PhiRotaryEmbedding {
302 fn new_classic_scaled(
303 short_factor: &[f64],
304 long_factor: &[f64],
305 scaling_type: &ScaledRopeType,
306 cfg: &PhiRopeConfig,
307 dtype: DType,
308 dev: &Device,
309 ) -> Result<Self> {
310 let max_seq_len = cfg.max_position_embeddings;
311 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
312
313 let scale =
315 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
316 let scaling_factor = if scale <= 1.0 {
317 1.0
318 } else {
319 match scaling_type {
320 ScaledRopeType::Su => {
321 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
322 }
323 ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
324 _ => candle_core::bail!("Expected either `su` or `yarn` RoPE"),
325 }
326 };
327
328 let inv_freq_long = (0..dim)
330 .step_by(2)
331 .enumerate()
332 .map(|(k, i)| {
333 (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
334 })
335 .collect::<Vec<_>>();
336 let inv_freq_short = (0..dim)
337 .step_by(2)
338 .enumerate()
339 .map(|(k, i)| {
340 (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
341 })
342 .collect::<Vec<_>>();
343 let inv_freq_len = inv_freq_long.len();
344
345 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
346 .to_dtype(DType::F32)?
347 .reshape((max_seq_len, 1))?;
348
349 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
351 let freqs_long = t.matmul(&inv_freq_long)?;
352 let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
353 let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
354
355 let inv_freq_short =
357 Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
358 let freqs_short = t.matmul(&inv_freq_short)?;
359 let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
360 let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
361
362 Ok(Self {
363 short_cos,
364 short_sin,
365 long_cos: Some(long_cos),
366 long_sin: Some(long_sin),
367 original_max_position_embeddings: cfg.original_max_position_embeddings,
368 })
369 }
370
371 fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
372 let max_seq_len = cfg.max_position_embeddings;
373 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
374
375 let inv_freq: Vec<_> = (0..dim)
376 .step_by(2)
377 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
378 .collect();
379 let inv_freq_len = inv_freq.len();
380 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
381 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
382 .to_dtype(DType::F32)?
383 .reshape((max_seq_len, 1))?;
384 let freqs = t.matmul(&inv_freq)?;
385 let sin = freqs.sin()?.to_dtype(dtype)?;
386 let cos = freqs.cos()?.to_dtype(dtype)?;
387 Ok(Self {
388 short_cos: cos,
389 short_sin: sin,
390 long_cos: None,
391 long_sin: None,
392 original_max_position_embeddings: cfg.original_max_position_embeddings,
393 })
394 }
395
396 #[allow(clippy::too_many_arguments)]
397 fn new_scaled(
398 short_factor: &[f64],
399 long_factor: &[f64],
400 scaling_type: &ScaledRopeType,
401 long_mscale: f64,
402 short_mscale: f64,
403 cfg: &PhiRopeConfig,
404 dtype: DType,
405 dev: &Device,
406 ) -> Result<Self> {
407 let max_seq_len = cfg.max_position_embeddings;
408 let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
409
410 if !matches!(scaling_type, ScaledRopeType::Su) {
411 candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
412 }
413
414 if short_factor.len() != dim / 2 {
415 candle_core::bail!(
416 "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
417 short_factor.len(),
418 dim / 2
419 );
420 }
421 if long_factor.len() != dim / 2 {
422 candle_core::bail!(
423 "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
424 long_factor.len(),
425 dim / 2
426 );
427 }
428
429 let inv_freq_short: Vec<_> = (0..dim)
431 .step_by(2)
432 .enumerate()
433 .map(|(k, i)| {
434 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
435 })
436 .collect();
437 let inv_freq_len_short = inv_freq_short.len();
438 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
439 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
440 .to_dtype(DType::F32)?
441 .reshape((max_seq_len, 1))?;
442 let freqs_short = t_short.matmul(&inv_freq_short)?;
443 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
444 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
445
446 let inv_freq_long: Vec<_> = (0..dim)
448 .step_by(2)
449 .enumerate()
450 .map(|(k, i)| {
451 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
452 })
453 .collect();
454 let inv_freq_len_long = inv_freq_long.len();
455 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
456 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
457 .to_dtype(DType::F32)?
458 .reshape((max_seq_len, 1))?;
459 let freqs_long = t_long.matmul(&inv_freq_long)?;
460 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
461 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
462 Ok(Self {
463 short_cos: cos_short,
464 short_sin: sin_short,
465 long_cos: Some(cos_long),
466 long_sin: Some(sin_long),
467 original_max_position_embeddings: cfg.original_max_position_embeddings,
468 })
469 }
470
471 pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
472 let cfg: PhiRopeConfig = cfg.into();
473
474 match &cfg.rope_scaling {
475 Some(PhiRopeScalingConfig::Classic {
476 short_factor,
477 long_factor,
478 scaling_type,
479 }) => {
480 Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
481 }
482
483 Some(PhiRopeScalingConfig::Scaled {
484 short_factor,
485 long_factor,
486 scaling_type,
487 long_mscale,
488 short_mscale,
489 }) => Self::new_scaled(
490 short_factor,
491 long_factor,
492 scaling_type,
493 *long_mscale,
494 *short_mscale,
495 &cfg,
496 dtype,
497 dev,
498 ),
499
500 None => Self::new_unscaled(&cfg, dtype, dev),
501 }
502 }
503
504 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
506 if self.long_cos.is_none() {
507 return (&self.short_sin, &self.short_cos);
508 }
509 let seq_len = position_ids.iter().max().unwrap() + 1;
510 if seq_len > self.original_max_position_embeddings {
511 (
512 self.long_sin.as_ref().unwrap(),
513 self.long_cos.as_ref().unwrap(),
514 )
515 } else {
516 (&self.short_sin, &self.short_cos)
517 }
518 }
519
520 pub fn forward(
521 &self,
522 q: &Tensor,
523 k: &Tensor,
524 seqlen_offsets: &[usize],
525 position_ids: &[usize],
526 ) -> Result<(Tensor, Tensor)> {
527 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
528 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
529
530 let rot_dim = cos.dim(D::Minus1)? * 2;
531
        if rot_dim != q.dim(D::Minus1)? {
            // Partial rotary embedding: only the first `rot_dim` features are rotated; the rest pass through.
            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
536 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
537 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
538 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
539
540 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
541 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
542 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
543 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
544 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
545 (q_embed, k_embed)
546 } else {
547 let mut q_embeds = Vec::new();
548 let mut k_embeds = Vec::new();
549 for (i, offset) in seqlen_offsets.iter().enumerate() {
550 let cos = cos.narrow(0, *offset, seq_len)?;
551 let sin = sin.narrow(0, *offset, seq_len)?;
552 let q_embed = candle_nn::rotary_emb::rope(
553 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
554 &cos,
555 &sin,
556 )?;
557 let k_embed = candle_nn::rotary_emb::rope(
558 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
559 &cos,
560 &sin,
561 )?;
562 q_embeds.push(q_embed);
563 k_embeds.push(k_embed);
564 }
565 let q_rot = Tensor::cat(&q_embeds, 0)?;
566 let k_rot = Tensor::cat(&k_embeds, 0)?;
567 (q_rot, k_rot)
568 };
569
570 Ok((
571 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
572 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
573 ))
574 } else if seqlen_offsets.len() == 1 {
575 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
576 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
577 let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
578 let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
579 Ok((q_embed, k_embed))
580 } else {
581 let mut q_embeds = Vec::new();
582 let mut k_embeds = Vec::new();
583 for (i, offset) in seqlen_offsets.iter().enumerate() {
584 let cos = cos.narrow(0, *offset, seq_len)?;
585 let sin = sin.narrow(0, *offset, seq_len)?;
586 let q_embed =
587 candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
588 let k_embed =
589 candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
590 q_embeds.push(q_embed);
591 k_embeds.push(k_embed);
592 }
593 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
594 }
595 }
596}
597
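/// Rotary embedding with optional Llama 3 frequency rescaling.
///
/// Each default inverse frequency is mapped through its wavelength `λ = 2π / f`: frequencies
/// with `λ < original_max / high_freq_factor` are kept, those with
/// `λ > original_max / low_freq_factor` are divided by `factor`, and the band in between is
/// blended with `smooth = (original_max / λ - low_freq_factor) / (high_freq_factor - low_freq_factor)`,
/// giving `(1 - smooth) * f / factor + smooth * f`.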
598#[derive(Debug, Clone)]
600pub struct Llama3RotaryEmbedding(RotaryEmbedding);
601
602#[derive(Debug, Clone, Deserialize, Serialize, Default)]
603pub enum Llama3RopeType {
604 #[serde(rename = "llama3")]
605 Llama3,
606 #[default]
607 #[serde(rename = "default")]
608 Default,
609}
610
611#[derive(Debug, Clone, Deserialize, Serialize, Default)]
612pub struct Llama3RopeConfig {
613 pub factor: f32,
614 pub low_freq_factor: f32,
615 pub high_freq_factor: f32,
616 pub original_max_position_embeddings: usize,
617 pub rope_type: Llama3RopeType,
618}
619
620fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
621 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
622 (0..head_dim)
623 .step_by(2)
624 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
625 .collect()
626}
627
628impl Llama3RotaryEmbedding {
630 pub fn new_llama3(
631 dtype: DType,
632 cfg: &llama::Config,
633 dev: &Device,
634 is_gpt_neox: bool,
635 ) -> Result<Self> {
636 match &cfg.rope_scaling {
637 None
638 | Some(Llama3RopeConfig {
639 rope_type: Llama3RopeType::Default,
640 ..
641 }) => Ok(Self(RotaryEmbedding::new(
642 cfg.rope_theta,
643 cfg.hidden_size / cfg.num_attention_heads,
644 cfg.max_position_embeddings,
645 dev,
646 is_gpt_neox,
647 dtype,
648 )?)),
649 Some(rope_scaling) => {
650 let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
651 / rope_scaling.low_freq_factor;
652 let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
653 / rope_scaling.high_freq_factor;
654
655 let inv_freq = calculate_default_inv_freq(cfg)
656 .into_iter()
657 .map(|freq| {
658 let wavelen = 2. * PI / freq;
659 if wavelen < high_freq_wavelen {
660 freq
661 } else if wavelen > low_freq_wavelen {
662 freq / rope_scaling.factor
663 } else {
664 let smooth = (rope_scaling.original_max_position_embeddings as f32
665 / wavelen
666 - rope_scaling.low_freq_factor)
667 / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
668 (1. - smooth) * freq / rope_scaling.factor + smooth * freq
669 }
670 })
671 .collect::<Vec<_>>();
672 let inv_freq_len = inv_freq.len();
673 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
674
675 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
676 .to_dtype(DType::F32)?
677 .reshape((cfg.max_position_embeddings, 1))?;
678 let freqs = t.matmul(&inv_freq)?;
679 let sin = freqs.sin()?.to_dtype(dtype)?;
680 let cos = freqs.cos()?.to_dtype(dtype)?;
681 Ok(Self(RotaryEmbedding {
682 sin,
683 cos,
684 is_gpt_neox,
685 }))
686 }
687 }
688 }
689
690 pub fn new_mllama3(
691 dtype: DType,
692 cfg: &MLlamaTextConfig,
693 dev: &Device,
694 is_gpt_neox: bool,
695 ) -> Result<Self> {
696 match &cfg.rope_scaling {
697 None
698 | Some(MLlamaRopeScaling {
699 rope_type: MLlamaRopeType::Default,
700 ..
701 }) => Ok(Self(RotaryEmbedding::new(
702 cfg.rope_theta,
703 cfg.hidden_size / cfg.num_attention_heads,
704 cfg.max_position_embeddings,
705 dev,
706 is_gpt_neox,
707 dtype,
708 )?)),
709 Some(MLlamaRopeScaling {
710 rope_type: MLlamaRopeType::Llama3,
711 original_max_position_embeddings,
712 factor,
713 attention_factor: _,
714 beta_fast: _,
715 beta_slow: _,
716 short_factor: _,
717 long_factor: _,
718 low_freq_factor,
719 high_freq_factor,
720 }) => {
721 let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
722 let low_freq_factor = low_freq_factor
723 .context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
724 let high_freq_factor = high_freq_factor
725 .context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
726
727 let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
728 let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
729
730 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
731
732 let inv_freq = (0..head_dim)
733 .step_by(2)
734 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
735 .map(|freq| {
736 let wavelen = 2. * PI / freq;
737 if wavelen < high_freq_wavelen {
738 freq
739 } else if wavelen > low_freq_wavelen {
740 freq / factor
741 } else {
742 let smooth = (*original_max_position_embeddings as f32 / wavelen
743 - low_freq_factor)
744 / (high_freq_factor - low_freq_factor);
745 (1. - smooth) * freq / factor + smooth * freq
746 }
747 })
748 .collect::<Vec<_>>();
749 let inv_freq_len = inv_freq.len();
750 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
751
752 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
753 .to_dtype(DType::F32)?
754 .reshape((cfg.max_position_embeddings, 1))?;
755 let freqs = t.matmul(&inv_freq)?;
756 let sin = freqs.sin()?.to_dtype(dtype)?;
757 let cos = freqs.cos()?.to_dtype(dtype)?;
758 Ok(Self(RotaryEmbedding {
759 sin,
760 cos,
761 is_gpt_neox,
762 }))
763 }
764 Some(MLlamaRopeScaling {
765 rope_type: other, ..
766 }) => {
767 candle_core::bail!(
768 "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
769 )
770 }
771 }
772 }
773
774 pub fn forward(
775 &self,
776 q: &Tensor,
777 k: &Tensor,
778 seqlen_offsets: &[usize],
779 ) -> Result<(Tensor, Tensor)> {
780 self.0.forward(q, k, seqlen_offsets)
781 }
782}
783
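/// Multimodal rotary embedding (M-RoPE) as used by Qwen2-VL.
///
/// `position_ids` carries three sets of positions (temporal, height, width), one per leading
/// index. `compute_cos_sin` builds cos/sin for all three, then `mrope_section` splits the rotary
/// dimensions into chunks whose source axis alternates via the `i % 3` indexing, and the chunks
/// are concatenated back into single cos/sin tables that `forward` applies to `q` and `k` in
/// place with `candle_nn::rotary_emb::rope`.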
784#[derive(Debug, Clone)]
786pub struct Qwen2VLRotaryEmbedding {
787 inv_freq: Tensor,
788 mrope_section: Vec<usize>,
789}
790
791impl Qwen2VLRotaryEmbedding {
792 pub fn new(
793 base: f32,
794 head_dim: usize,
795 device: &Device,
796 mrope_section: Vec<usize>,
797 ) -> Result<Self> {
798 let inv_freq: Vec<_> = (0..head_dim)
799 .step_by(2)
800 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
801 .collect();
802 let inv_freq_len = inv_freq.len();
803 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
804 Ok(Self {
805 inv_freq,
806 mrope_section,
807 })
808 }
809
810 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
812 let inv_freq_expanded =
813 self.inv_freq
814 .reshape((1, 1, (), 1))?
815 .repeat((3, position_ids.dim(1)?, 1, 1))?;
816 let position_ids_expanded = position_ids.unsqueeze(2)?;
817 let freqs = inv_freq_expanded
818 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
819 .transpose(2, 3)?;
820 let cos = freqs.cos()?;
821 let sin = freqs.sin()?;
822
823 let cos = Tensor::cat(
824 &cos.split(&self.mrope_section, D::Minus1)?
825 .into_iter()
826 .enumerate()
827 .map(|(i, m)| m.i(i % 3))
828 .collect::<Result<Vec<_>>>()?,
829 D::Minus1,
830 )?
831 .squeeze(0)?
832 .to_dtype(dtype)?
833 .contiguous()?;
834 let sin = Tensor::cat(
835 &sin.split(&self.mrope_section, D::Minus1)?
836 .into_iter()
837 .enumerate()
838 .map(|(i, m)| m.i(i % 3))
839 .collect::<Result<Vec<_>>>()?,
840 D::Minus1,
841 )?
842 .squeeze(0)?
843 .to_dtype(dtype)?
844 .contiguous()?;
845
846 Ok((cos, sin))
847 }
848
849 pub fn forward(
851 &self,
852 (cos, sin): &(Tensor, Tensor),
853 q: &mut Tensor,
854 k: &mut Tensor,
855 ) -> Result<()> {
856 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
857 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
858 Ok(())
859 }
860}
861
862#[derive(Debug, Clone)]
863pub struct Qwen2_5VLRotaryEmbedding {
864 inv_freq: Tensor,
865 mrope_section: Vec<usize>,
866}
867
868impl Qwen2_5VLRotaryEmbedding {
869 pub fn new(
870 base: f32,
871 head_dim: usize,
872 device: &Device,
873 mrope_section: Vec<usize>,
874 ) -> Result<Self> {
875 let inv_freq: Vec<_> = (0..head_dim)
876 .step_by(2)
877 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
878 .collect();
879 let inv_freq_len = inv_freq.len();
880 let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
881 Ok(Self {
882 inv_freq,
883 mrope_section,
884 })
885 }
886
887 pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
889 let inv_freq_expanded =
890 self.inv_freq
891 .reshape((1, 1, (), 1))?
892 .repeat((3, position_ids.dim(1)?, 1, 1))?;
893 let position_ids_expanded = position_ids.unsqueeze(2)?;
894 let freqs = inv_freq_expanded
895 .matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
896 .transpose(2, 3)?;
897 let cos = freqs.cos()?;
898 let sin = freqs.sin()?;
899
900 let cos = Tensor::cat(
901 &cos.split(&self.mrope_section, D::Minus1)?
902 .into_iter()
903 .enumerate()
904 .map(|(i, m)| m.i(i % 3))
905 .collect::<Result<Vec<_>>>()?,
906 D::Minus1,
907 )?
908 .squeeze(0)?
909 .to_dtype(dtype)?
910 .contiguous()?;
911 let sin = Tensor::cat(
912 &sin.split(&self.mrope_section, D::Minus1)?
913 .into_iter()
914 .enumerate()
915 .map(|(i, m)| m.i(i % 3))
916 .collect::<Result<Vec<_>>>()?,
917 D::Minus1,
918 )?
919 .squeeze(0)?
920 .to_dtype(dtype)?
921 .contiguous()?;
922
923 Ok((cos, sin))
924 }
925
926 pub fn forward(
927 &self,
928 (cos, sin): &(Tensor, Tensor),
929 q: &mut Tensor,
930 k: &mut Tensor,
931 ) -> Result<()> {
932 *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
933 *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
934 Ok(())
935 }
936}
937
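/// Rotary embedding for DeepSeek-V2 style attention, with optional YaRN scaling.
///
/// For YaRN, extrapolated (`freq_extra`) and interpolated (`freq_inter`, divided by `factor`)
/// inverse frequencies are blended with a linear ramp between the correction dims derived from
/// `beta_fast`/`beta_slow`, and the resulting cos/sin tables are scaled by
/// `yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim)`, where
/// `yarn_get_mscale(s, m) = 0.1 * m * ln(s) + 1` for `s > 1`.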
938#[derive(Debug, Clone)]
939pub struct DeepSeekV2RotaryEmbedding {
940 sin: Tensor,
941 cos: Tensor,
942}
943
944#[derive(Debug, Clone, Deserialize, Serialize)]
945#[serde(untagged)]
946pub enum DeepSeekV2RopeScaling {
947 Yarn {
948 original_max_position_embeddings: usize,
949 beta_fast: f32,
950 beta_slow: f32,
951 mscale: f32,
952 mscale_all_dim: f32,
953 factor: f32,
954 #[serde(rename = "type")]
955 scaling_type: ScaledRopeType,
956 },
957 LinearOrDynamic {
958 #[serde(rename = "type")]
959 scaling_type: ScaledRopeType,
960 factor: f64,
961 },
962}
963
964pub struct DeepSeekV2RopeConfig {
965 pub rope_scaling: Option<DeepSeekV2RopeScaling>,
966 pub max_position_embeddings: usize,
967 pub rope_theta: f32,
968 pub qk_rope_head_dim: usize,
969}
970
971impl DeepSeekV2RotaryEmbedding {
972 fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
973 let max_seq_len = cfg.max_position_embeddings;
974 let dim = cfg.qk_rope_head_dim;
975
976 let inv_freq: Vec<_> = (0..dim)
977 .step_by(2)
978 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32))
979 .collect();
980 let inv_freq_len = inv_freq.len();
981 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
982 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
983 .to_dtype(DType::F32)?
984 .reshape((max_seq_len, 1))?;
985 let freqs = t.matmul(&inv_freq)?;
986
987 let sin = freqs.sin()?.to_dtype(dtype)?;
988 let cos = freqs.cos()?.to_dtype(dtype)?;
989
990 Ok(Self { sin, cos })
991 }
992
993 fn yarn_find_correction_dim(
994 num_rot: f32,
995 dim: usize,
996 base: f32,
997 max_position_embeddings: usize,
998 ) -> f32 {
999 (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln())
1000 / (2. * base.ln())
1001 }
1002
1003 fn yarn_find_correction_range(
1004 low_rot: f32,
1005 high_rot: f32,
1006 dim: usize,
1007 base: f32,
1008 max_position_embeddings: usize,
1009 ) -> (f32, f32) {
1010 let low =
1011 Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor();
1012 let high =
1013 Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil();
1014 (low.max(0.), high.min(dim as f32 - 1.))
1015 }
1016
1017 fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result<Tensor> {
1018 if min == max {
1019 max += 0.001;
1021 }
1022 let linear_func =
1023 ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?;
1024 linear_func.clamp(0., 1)
1025 }
1026
1027 pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
1028 if scale <= 1. {
1029 return 1.;
1030 }
1031 0.1 * mscale * scale.ln() + 1.
1032 }
1033
1034 #[allow(clippy::too_many_arguments)]
1035 fn new_yarn(
1036 cfg: &DeepSeekV2RopeConfig,
1037 dtype: DType,
1038 dev: &Device,
1039 original_max_position_embeddings: usize,
1040 beta_fast: f32,
1041 beta_slow: f32,
1042 factor: f32,
1043 mscale: f32,
1044 mscale_all_dim: f32,
1045 ) -> Result<Self> {
1046 let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim)
1047 .step_by(2)
1048 .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))
1049 .collect();
1050 let freq_extra_len = freq_extra.len();
1051 let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?;
1052 let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim)
1053 .step_by(2)
1054 .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)))
1055 .collect();
1056 let freq_inter_len = freq_inter.len();
1057 let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?;
1058
1059 let (low, high) = Self::yarn_find_correction_range(
1060 beta_fast,
1061 beta_slow,
1062 cfg.qk_rope_head_dim,
1063 cfg.rope_theta,
1064 original_max_position_embeddings,
1065 );
1066 let inv_freq_mask =
1067 (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?;
1068 let inv_freq = freq_inter
1069 .broadcast_mul(&(1. - &inv_freq_mask)?)?
1070 .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?;
1071
1072 let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
1073 .to_dtype(DType::F32)?
1074 .reshape((cfg.max_position_embeddings, 1))?;
1075 let freqs = t.matmul(&inv_freq)?;
1076
1077 let mscale =
1078 Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim);
1079 let sin = (freqs.sin()? * mscale as f64)?.to_dtype(dtype)?;
1080 let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?;
1081
1082 Ok(Self { sin, cos })
1083 }
1084
1085 pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
1086 match &cfg.rope_scaling {
1087 Some(DeepSeekV2RopeScaling::LinearOrDynamic {
1088 scaling_type: _,
1089 factor: _,
1090 }) => candle_core::bail!("linear and dynamic rope are not implemented yet!"),
1091 Some(DeepSeekV2RopeScaling::Yarn {
1092 original_max_position_embeddings,
1093 beta_fast,
1094 beta_slow,
1095 factor,
1096 mscale,
1097 mscale_all_dim,
1098 scaling_type: _,
1099 }) => Self::new_yarn(
1100 cfg,
1101 dtype,
1102 dev,
1103 *original_max_position_embeddings,
1104 *beta_fast,
1105 *beta_slow,
1106 *factor,
1107 *mscale,
1108 *mscale_all_dim,
1109 ),
1110 None => Self::new_unscaled(cfg, dtype, dev),
1111 }
1112 }
1113
1114 pub fn forward(
1115 &self,
1116 q: &Tensor,
1117 k: &Tensor,
1118 seqlen_offsets: &[usize],
1119 ) -> Result<(Tensor, Tensor)> {
1120 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1121
1122 if seqlen_offsets.len() == 1 {
1123 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1124 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1125 let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?;
1126 let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?;
1127 Ok((q_embed, k_embed))
1128 } else {
1129 let mut q_embeds = Vec::new();
1130 let mut k_embeds = Vec::new();
1131 for (i, offset) in seqlen_offsets.iter().enumerate() {
1132 let cos = self.cos.narrow(0, *offset, seq_len)?;
1133 let sin = self.sin.narrow(0, *offset, seq_len)?;
1134 let q_embed = candle_nn::rotary_emb::rope_i(
1135 &q.i(i)?.unsqueeze(0)?.contiguous()?,
1136 &cos,
1137 &sin,
1138 )?;
1139 let k_embed = candle_nn::rotary_emb::rope_i(
1140 &k.i(i)?.unsqueeze(0)?.contiguous()?,
1141 &cos,
1142 &sin,
1143 )?;
1144 q_embeds.push(q_embed);
1145 k_embeds.push(k_embed);
1146 }
1147 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1148 }
1149 }
1150}
1151
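/// Phi-4 multimodal rotary embedding using the same short/long LongRoPE tables as
/// [`PhiRotaryEmbedding`], with both tables scaled by
/// `sqrt(1 + ln(max_position / original_max_position) / ln(original_max_position))` when the
/// context is extended beyond the original maximum.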
1152#[derive(Debug, Clone)]
1153pub struct Phi4MMRotaryEmbedding {
1154 short_sin: Tensor,
1155 short_cos: Tensor,
1156 long_cos: Option<Tensor>,
1157 long_sin: Option<Tensor>,
1158 original_max_position_embeddings: usize,
1159}
1160
1161#[derive(Debug, Clone, Default, Deserialize, Serialize)]
1162#[serde(rename_all = "lowercase")]
1163pub enum Phi4MMScaledRopeType {
1164 #[serde(alias = "longrope")]
1165 LongRope,
1166 #[default]
1167 Default,
1168}
1169
1170#[derive(Debug, Clone, Deserialize, Serialize)]
1171pub struct Phi4MMRopeScalingConfig {
1172 short_factor: Option<Vec<f64>>,
1173 long_factor: Option<Vec<f64>>,
1174 #[serde(rename = "type")]
1175 scaling_type: Phi4MMScaledRopeType,
1176}
1177
1178impl Phi4MMRotaryEmbedding {
1179 fn new_unscaled(cfg: &Phi4MMConfig, dtype: DType, dev: &Device) -> Result<Self> {
1180 let max_seq_len = cfg.max_position_embeddings;
1181 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1182
1183 let inv_freq: Vec<_> = (0..dim)
1184 .step_by(2)
1185 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1186 .collect();
1187 let inv_freq_len = inv_freq.len();
1188 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1189 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1190 .to_dtype(DType::F32)?
1191 .reshape((max_seq_len, 1))?;
1192 let freqs = t.matmul(&inv_freq)?;
1193 let sin = freqs.sin()?.to_dtype(dtype)?;
1194 let cos = freqs.cos()?.to_dtype(dtype)?;
1195 Ok(Self {
1196 short_cos: cos,
1197 short_sin: sin,
1198 long_cos: None,
1199 long_sin: None,
1200 original_max_position_embeddings: cfg.original_max_position_embeddings,
1201 })
1202 }
1203
1204 #[allow(clippy::too_many_arguments)]
1205 fn new_longrope(
1206 short_factor: &[f64],
1207 long_factor: &[f64],
1208 cfg: &Phi4MMConfig,
1209 dtype: DType,
1210 dev: &Device,
1211 ) -> Result<Self> {
1212 let max_seq_len = cfg.max_position_embeddings;
1213 let dim = (cfg.head_dim() as f64 * cfg.partial_rotary_factor) as usize;
1214
1215 let scale =
1217 cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
1218 let scaling_factor = if scale <= 1.0 {
1219 1.0
1220 } else {
1221 (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
1222 };
1223
1224 let inv_freq_short: Vec<_> = (0..dim)
1226 .step_by(2)
1227 .enumerate()
1228 .map(|(k, i)| {
1229 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1230 })
1231 .collect();
1232 let inv_freq_len_short = inv_freq_short.len();
1233 let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
1234 let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
1235 .to_dtype(DType::F32)?
1236 .reshape((max_seq_len, 1))?;
1237 let freqs_short = t_short.matmul(&inv_freq_short)?;
1238 let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * scaling_factor)?;
1239 let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * scaling_factor)?;
1240
1241 let inv_freq_long: Vec<_> = (0..dim)
1243 .step_by(2)
1244 .enumerate()
1245 .map(|(k, i)| {
1246 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
1247 })
1248 .collect();
1249 let inv_freq_len_long = inv_freq_long.len();
1250 let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
1251 let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
1252 .to_dtype(DType::F32)?
1253 .reshape((max_seq_len, 1))?;
1254 let freqs_long = t_long.matmul(&inv_freq_long)?;
1255 let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * scaling_factor)?;
1256 let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * scaling_factor)?;
1257
1258 Ok(Self {
1259 short_cos: cos_short,
1260 short_sin: sin_short,
1261 long_cos: Some(cos_long),
1262 long_sin: Some(sin_long),
1263 original_max_position_embeddings: cfg.original_max_position_embeddings,
1264 })
1265 }
1266
1267 pub fn new(dtype: DType, cfg: &Phi4MMConfig, dev: &Device) -> Result<Self> {
1268 match &cfg.rope_scaling {
1269 Some(Phi4MMRopeScalingConfig {
1270 scaling_type: Phi4MMScaledRopeType::LongRope,
1271 short_factor: Some(short_factor),
1272 long_factor: Some(long_factor),
1273 }) => Self::new_longrope(short_factor, long_factor, cfg, dtype, dev),
1274
1275 _ => Self::new_unscaled(cfg, dtype, dev),
1276 }
1277 }
1278
1279 fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
1281 if self.long_cos.is_none() {
1282 return (&self.short_sin, &self.short_cos);
1283 }
1284 let seq_len = position_ids.iter().max().unwrap() + 1;
1285 if seq_len > self.original_max_position_embeddings {
1286 (
1287 self.long_sin.as_ref().unwrap(),
1288 self.long_cos.as_ref().unwrap(),
1289 )
1290 } else {
1291 (&self.short_sin, &self.short_cos)
1292 }
1293 }
1294
1295 pub fn forward(
1296 &self,
1297 q: &Tensor,
1298 k: &Tensor,
1299 seqlen_offsets: &[usize],
1300 position_ids: &[usize],
1301 ) -> Result<(Tensor, Tensor)> {
1302 let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
1303 let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
1304
1305 let rot_dim = cos.dim(D::Minus1)? * 2;
1306 let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
1307 let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
1308 let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
1309 let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
1310
1311 let (q_rot, k_rot) = if seqlen_offsets.len() == 1 {
1312 let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
1313 let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
1314 let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
1315 let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
1316 (q_embed, k_embed)
1317 } else {
1318 let mut q_embeds = Vec::new();
1319 let mut k_embeds = Vec::new();
1320 for (i, offset) in seqlen_offsets.iter().enumerate() {
1321 let cos = cos.narrow(0, *offset, seq_len)?;
1322 let sin = sin.narrow(0, *offset, seq_len)?;
1323 let q_embed = candle_nn::rotary_emb::rope(
1324 &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1325 &cos,
1326 &sin,
1327 )?;
1328 let k_embed = candle_nn::rotary_emb::rope(
1329 &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
1330 &cos,
1331 &sin,
1332 )?;
1333 q_embeds.push(q_embed);
1334 k_embeds.push(k_embed);
1335 }
1336 let q_rot = Tensor::cat(&q_embeds, 0)?;
1337 let k_rot = Tensor::cat(&k_embeds, 0)?;
1338 (q_rot, k_rot)
1339 };
1340
1341 Ok((
1342 Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
1343 Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
1344 ))
1345 }
1346}
1347
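/// Rotary embedding for Gemma 3; the only supported scaling is `linear`, which simply divides
/// the inverse frequencies by `factor` (a factor of 1.0 is used when no scaling is configured).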
1348#[derive(Debug, Clone)]
1349pub struct Gemma3RotaryEmbedding(RotaryEmbedding);
1350
1351#[derive(Debug, Clone, Deserialize, Serialize)]
1352#[serde(rename_all = "lowercase")]
1353pub enum Gemmma3ScaledRopeType {
1354 #[serde(alias = "linear")]
1355 Linear,
1356}
1357
1358#[derive(Debug, Clone, Deserialize, Serialize)]
1359pub struct Gemma3RopeScalingConfig {
1360 factor: f64,
1361 rope_type: Gemmma3ScaledRopeType,
1362}
1363
1364impl Gemma3RotaryEmbedding {
1365 fn new_linear(
1366 cfg: &Gemma3TextConfig,
1367 factor: f64,
1368 is_gpt_neox: bool,
1369 dtype: DType,
1370 dev: &Device,
1371 ) -> Result<Self> {
1372 let max_seq_len = cfg.max_position_embeddings;
1373 let dim = cfg.head_dim;
1374
1375 let inv_freq: Vec<_> = (0..dim)
1376 .step_by(2)
1377 .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
1378 .collect();
1379 let inv_freq_len = inv_freq.len();
1380 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
1381 let inv_freq = (inv_freq / factor)?;
1382
1383 let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
1384 .to_dtype(DType::F32)?
1385 .reshape((max_seq_len, 1))?;
1386 let freqs = t.matmul(&inv_freq)?;
1387 let sin = freqs.sin()?.to_dtype(dtype)?;
1388 let cos = freqs.cos()?.to_dtype(dtype)?;
1389 Ok(Self(RotaryEmbedding {
1390 cos,
1391 sin,
1392 is_gpt_neox,
1393 }))
1394 }
1395
1396 pub fn new(
1397 is_gpt_neox: bool,
1398 dtype: DType,
1399 cfg: &Gemma3TextConfig,
1400 dev: &Device,
1401 ) -> Result<Self> {
1402 match &cfg.rope_scaling {
1403 Some(Gemma3RopeScalingConfig {
1404 rope_type: Gemmma3ScaledRopeType::Linear,
1405 factor,
1406 }) => Self::new_linear(cfg, *factor, is_gpt_neox, dtype, dev),
1407
1408 _ => Self::new_linear(cfg, 1.0, is_gpt_neox, dtype, dev),
1409 }
1410 }
1411
1412 pub fn forward(
1413 &self,
1414 q: &Tensor,
1415 k: &Tensor,
1416 seqlen_offsets: &[usize],
1417 ) -> Result<(Tensor, Tensor)> {
1418 self.0.forward(q, k, seqlen_offsets)
1419 }
1420}
1421
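/// A linear layer backed by a [`QMatMul`] (either a quantized `QTensor` or a plain tensor) plus
/// an optional bias. When the inner matmul is quantized, activations are cast to `F32` for the
/// matmul and the result is cast back to the stored `dtype` afterwards.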
1422#[derive(Debug, Clone)]
1423pub struct QLinear {
1424 inner: QMatMul,
1425 bias: Option<Tensor>,
1426 dtype: DType,
1427}
1428
1429impl QLinear {
1430 pub fn new<R: std::io::Read + std::io::Seek>(
1431 ct: &mut Content<'_, R>,
1432 name: &str,
1433 device: &Device,
1434 ) -> Result<Self> {
1435 let w = ct.tensor(&format!("{name}.weight"), device)?;
1436 let b = ct.tensor(&format!("{name}.bias"), device)?;
1437 let inner = QMatMul::from_qtensor(w)?;
1438 let bias = b.dequantize(device)?;
1439 Ok(Self {
1440 inner,
1441 bias: Some(bias),
1442 dtype: DType::F32,
1443 })
1444 }
1445
1446 pub fn from_linear(linear: Linear) -> Self {
1447 Self {
1448 inner: QMatMul::Tensor(linear.weight().clone()),
1449 bias: linear.bias().cloned(),
1450 dtype: linear.weight().dtype(),
1451 }
1452 }
1453
1454 pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
1455 let dtype = w.dtype();
1456 Self {
1457 inner: QMatMul::Tensor(w),
1458 bias: b,
1459 dtype,
1460 }
1461 }
1462
1463 pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
1464 if let Some(ref b) = b {
1465 assert_eq!(b.dtype(), DType::F32);
1466 }
1467 Self {
1468 inner: QMatMul::QTensor(Arc::new(w)),
1469 bias: b,
1470 dtype: DType::F32,
1471 }
1472 }
1473
1474 pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
1475 Self {
1476 inner,
1477 bias: old.bias.clone(),
1478 dtype: old.dtype,
1479 }
1480 }
1481
1482 pub fn inner(&mut self) -> &mut QMatMul {
1483 &mut self.inner
1484 }
1485
1486 pub fn inner_ref(&self) -> &QMatMul {
1487 &self.inner
1488 }
1489
1490 pub fn is_quant(&self) -> bool {
1491 matches!(self.inner, QMatMul::QTensor(_))
1492 }
1493
1494 pub fn bias(&self) -> Option<&Tensor> {
1495 self.bias.as_ref()
1496 }
1497
1498 pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
1499 self.bias.as_mut()
1500 }
1501}
1502
1503impl Module for QLinear {
1504 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1505 let xs = if self.is_quant() {
1506 xs.to_dtype(DType::F32)?
1507 } else {
1508 xs.clone()
1509 };
1510 if let Some(bias) = &self.bias {
1511 self.inner
1512 .forward(&xs)?
1513 .broadcast_add(bias)?
1514 .to_dtype(self.dtype)
1515 } else {
1516 self.inner.forward(&xs)?.to_dtype(self.dtype)
1517 }
1518 }
1519}
1520
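/// Standard rotary position embedding with precomputed cos/sin tables.
///
/// `is_gpt_neox` selects the half-rotation layout (`candle_nn::rotary_emb::rope`) rather than
/// the interleaved layout (`rope_i`). On builds with the `cuda` feature, a fused in-place kernel
/// from `mistralrs_quant` is used instead.
///
/// A rough usage sketch (the base, head dim, sequence length, and tensor names are illustrative
/// only):
///
/// ```ignore
/// let rope = RotaryEmbedding::new(10_000.0, 128, 4096, &device, true, DType::BF16)?;
/// // q, k: (batch, n_heads, seq_len, head_dim); one offset per sequence in the batch.
/// let (q, k) = rope.forward(&q, &k, &[0])?;
/// ```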
1521#[derive(Debug, Clone)]
1522pub struct RotaryEmbedding {
1523 cos: Tensor,
1524 sin: Tensor,
1525 is_gpt_neox: bool,
1526}
1527
1528impl RotaryEmbedding {
1529 pub fn new(
1530 base: f32,
1531 head_dim: usize,
1532 max_position_embeddings: usize,
1533 device: &Device,
1534 is_gpt_neox: bool,
1535 dtype: DType,
1536 ) -> Result<Self> {
1537 let inv_freq: Vec<_> = (0..head_dim)
1538 .step_by(2)
1539 .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
1540 .collect();
1541 let inv_freq_len = inv_freq.len();
1542 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1543 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1544 .to_dtype(DType::F32)?
1545 .reshape((max_position_embeddings, 1))?;
1546 let freqs = t.matmul(&inv_freq)?;
1547 let sin = freqs.sin()?.to_dtype(dtype)?;
1548 let cos = freqs.cos()?.to_dtype(dtype)?;
1549
1550 Ok(Self {
1551 cos,
1552 sin,
1553 is_gpt_neox,
1554 })
1555 }
1556
1557 pub fn new_partial(
1558 base: f32,
1559 rot_dim: usize,
1560 max_position_embeddings: usize,
1561 device: &Device,
1562 is_gpt_neox: bool,
1563 dtype: DType,
1564 ) -> Result<Self> {
1565 let inv_freq: Vec<_> = (0..rot_dim)
1566 .step_by(2)
1567 .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32))
1568 .collect();
1569 let inv_freq_len = inv_freq.len();
1570 let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), device)?;
1571 let t = Tensor::arange(0u32, max_position_embeddings as u32, device)?
1572 .to_dtype(DType::F32)?
1573 .reshape((max_position_embeddings, 1))?;
1574 let freqs = t.matmul(&inv_freq)?;
1575 let sin = freqs.sin()?.to_dtype(dtype)?;
1576 let cos = freqs.cos()?.to_dtype(dtype)?;
1577
1578 Ok(Self {
1579 cos,
1580 sin,
1581 is_gpt_neox,
1582 })
1583 }
1584
1585 pub fn forward(
1586 &self,
1587 q: &Tensor,
1588 k: &Tensor,
1589 seqlen_offsets: &[usize],
1590 ) -> Result<(Tensor, Tensor)> {
1591 let (b_sz, qh, seq_len, n_embd) = q.dims4()?;
1592 let (_b_sz, kh, _seq_len, __n_embd) = k.dims4()?;
1593
1594 let rope = if self.is_gpt_neox {
1595 candle_nn::rotary_emb::rope
1596 } else {
1597 candle_nn::rotary_emb::rope_i
1598 };
1599
1600 if cfg!(feature = "cuda") {
1601 let (cos, sin) = if seqlen_offsets.len() == 1 {
1602 (
1603 self.cos.narrow(0, seqlen_offsets[0], seq_len)?,
1604 self.sin.narrow(0, seqlen_offsets[0], seq_len)?,
1605 )
1606 } else {
1607 let mut cos_s = Vec::new();
1608 let mut sin_s = Vec::new();
1609 for offset in seqlen_offsets {
1610 cos_s.push(self.cos.narrow(0, *offset, seq_len)?);
1611 sin_s.push(self.sin.narrow(0, *offset, seq_len)?);
1612 }
1613 (Tensor::cat(&cos_s, 0)?, Tensor::cat(&sin_s, 0)?)
1614 };
1615
1616 let q_embed = q.transpose(1, 2)?.flatten(0, 1)?;
1617 let k_embed = k.transpose(1, 2)?.flatten(0, 1)?;
1618 mistralrs_quant::rotary::apply_rotary_inplace(
1619 &q_embed,
1620 &k_embed,
1621 &cos,
1622 &sin,
1623 self.is_gpt_neox,
1624 )?;
1625 let mut q = q_embed
1626 .reshape((b_sz, seq_len, qh, n_embd))?
1627 .transpose(1, 2)?;
1628 let mut k = k_embed
1629 .reshape((b_sz, seq_len, kh, n_embd))?
1630 .transpose(1, 2)?;
1631 if !(cfg!(feature = "flash-attn") || cfg!(feature = "flash-attn-v3")) {
1632 q = q.contiguous()?;
1633 k = k.contiguous()?;
1634 }
1635 Ok((q, k))
1636 } else if seqlen_offsets.len() == 1 {
1637 let cos = self.cos.narrow(0, seqlen_offsets[0], seq_len)?;
1638 let sin = self.sin.narrow(0, seqlen_offsets[0], seq_len)?;
1639 let q_embed = rope(&q.contiguous()?, &cos, &sin)?;
1640 let k_embed = rope(&k.contiguous()?, &cos, &sin)?;
1641 Ok((q_embed, k_embed))
1642 } else {
1643 let mut q_embeds = Vec::new();
1644 let mut k_embeds = Vec::new();
1645 for (i, offset) in seqlen_offsets.iter().enumerate() {
1646 let cos = self.cos.narrow(0, *offset, seq_len)?;
1647 let sin = self.sin.narrow(0, *offset, seq_len)?;
1648 let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1649 let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
1650 q_embeds.push(q_embed);
1651 k_embeds.push(k_embed);
1652 }
1653 Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
1654 }
1655 }
1656}
1657
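/// Activation functions as they appear in Hugging Face model configs; the serde aliases
/// (`gelu`, `gelu_new`, `gelu_pytorch_tanh`, ...) let the enum be deserialized straight from a
/// config's `hidden_act` string.
///
/// Illustrative sketch of how a config string maps onto the enum:
///
/// ```ignore
/// let act: Activation = serde_json::from_str("\"silu\"").unwrap();
/// assert_eq!(act, Activation::Silu);
/// ```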
1658#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
1659#[serde(rename_all = "lowercase")]
1660pub enum Activation {
1661 #[default]
1662 #[serde(alias = "gelu")]
1663 Gelu,
1664 #[serde(alias = "gelu_new")]
1665 NewGelu,
1666 Relu,
1667 Relu2,
1668 Relu6,
1669 Silu,
1670 Sigmoid,
1671 HardSigmoid,
1672 Swiglu,
1673 Swish,
1674 HardSwish,
1675 Elu(f64),
1676 LeakyRelu(f64),
1677 #[serde(alias = "gelu_pytorch_tanh")]
1678 GeluPytorchTanh,
1679 QuickGelu,
1680}
1681
1682impl Module for Activation {
1683 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1684 match self {
1685 Self::Gelu => xs.gelu_erf(),
1686 Self::NewGelu => xs.gelu(),
1688 Self::Relu => xs.relu(),
1689 Self::Relu2 => xs.relu()?.sqr(),
1690 Self::Relu6 => xs.clamp(0f32, 6f32),
1691 Self::Silu => xs.silu(),
1692 Self::Sigmoid => candle_nn::ops::sigmoid(xs),
1693 Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
1694 Self::Swiglu => candle_nn::ops::swiglu(xs),
1695 Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
1696 Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
1697 &Self::Elu(alpha) => xs.elu(alpha),
1698 &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
1699 Self::GeluPytorchTanh => xs.gelu(),
1700 Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
1701 }
1702 }
1703}
1704
1705impl TryInto<candle_nn::Activation> for Activation {
1706 type Error = candle_core::Error;
1707
1708 fn try_into(self) -> Result<candle_nn::Activation> {
1709 match self {
1710 Self::Gelu => Ok(candle_nn::Activation::Gelu),
1711 Self::Relu => Ok(candle_nn::Activation::Relu),
1712 Self::Silu => Ok(candle_nn::Activation::Silu),
1713 Self::NewGelu => Ok(candle_nn::Activation::NewGelu),
1714 Self::Relu2 => Ok(candle_nn::Activation::Relu2),
1715 Self::Relu6 => Ok(candle_nn::Activation::Relu6),
1716 Self::Sigmoid => Ok(candle_nn::Activation::Sigmoid),
1717 Self::HardSigmoid => Ok(candle_nn::Activation::HardSigmoid),
1718 Self::Swiglu => Ok(candle_nn::Activation::Swiglu),
1719 Self::Swish => Ok(candle_nn::Activation::Swish),
1720 Self::HardSwish => Ok(candle_nn::Activation::HardSwish),
1721 Self::Elu(x) => Ok(candle_nn::Activation::Elu(x)),
1722 Self::LeakyRelu(x) => Ok(candle_nn::Activation::LeakyRelu(x)),
1723 Self::GeluPytorchTanh => Ok(candle_nn::Activation::GeluPytorchTanh),
1724 Self::QuickGelu => candle_core::bail!("No mapping to candle_nn for QuickGelu"),
1725 }
1726 }
1727}
1728
1729#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1730pub struct Conv3dConfig {
1731 pub padding: usize,
1732 pub stride: usize,
1733 pub dilation: usize,
1734 pub groups: usize,
1735}
1736
1737impl Default for Conv3dConfig {
1738 fn default() -> Self {
1739 Self {
1740 padding: 0,
1741 stride: 1,
1742 dilation: 1,
1743 groups: 1,
1744 }
1745 }
1746}
1747
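/// A bias-free 3D convolution implemented as two 2D convolutions.
///
/// The weight is loaded with a temporal kernel dimension of size 2; each temporal slice becomes
/// its own [`Conv2d`], and `forward` applies them to the two temporal slices of the input and
/// sums the results (so inputs are expected to have a temporal extent of exactly 2).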
1748pub struct Conv3dNoBias {
1749 conv2d_1: Conv2d,
1750 conv2d_2: Conv2d,
1751}
1752
1753impl Conv3dNoBias {
1754 pub fn new(
1755 in_channels: usize,
1756 out_channels: usize,
1757 kernel_sizes: [usize; 3],
1758 cfg: Conv3dConfig,
1759 vb: ShardedVarBuilder,
1760 ) -> Result<Self> {
1761 let ws = vb.get(
1762 (
1763 out_channels,
1764 in_channels / cfg.groups,
1765 kernel_sizes[0],
1766 kernel_sizes[1],
1767 kernel_sizes[2],
1768 ),
1769 "weight",
1770 )?;
1771
        // Split the temporal (depth) dimension of the 3D kernel into two 2D kernels; this (and
        // `forward` below) assumes the temporal kernel size `kernel_sizes[0]` is 2.
        let w1 = ws.i((.., .., 0, .., ..))?;
        let w2 = ws.i((.., .., 1, .., ..))?;
1777
1778 let cfg = Conv2dConfig {
1779 padding: cfg.padding,
1780 stride: cfg.stride,
1781 dilation: cfg.dilation,
1782 groups: cfg.groups,
1783 };
1784
1785 Ok(Self {
1786 conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
1787 conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
1788 })
1789 }
1790}
1791
1792impl Module for Conv3dNoBias {
1793 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1794 let xs1 = xs.i((.., .., 0, .., ..))?;
1795 let xs2 = xs.i((.., .., 1, .., ..))?;
1796
1797 (self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
1798 }
1799}
1800
1801pub trait TensorInfExtend {
1802 fn is_inf(&self) -> Result<Self>
1803 where
1804 Self: Sized;
1805 fn any(&self) -> Result<bool>;
1806}
1807
1808impl TensorInfExtend for Tensor {
1809 fn is_inf(&self) -> Result<Self> {
1810 self.broadcast_eq(&Tensor::new(f32::INFINITY, self.device())?.to_dtype(self.dtype())?)
1811 }
1812
    fn any(&self) -> Result<bool> {
        // Returns true if any element is nonzero (e.g. any `true` flag produced by `is_inf`).
        let sum = self.sum_all()?;
        match self.dtype() {
            DType::U8 => Ok(sum.to_scalar::<u8>()? != 0),
            DType::U32 => Ok(sum.to_scalar::<u32>()? != 0),
            DType::I16 => Ok(sum.to_scalar::<i16>()? != 0),
            DType::I32 => Ok(sum.to_scalar::<i32>()? != 0),
            DType::I64 => Ok(sum.to_scalar::<i64>()? != 0),
            DType::F16 => Ok(sum.to_scalar::<half::f16>()? != half::f16::from_f32_const(0.)),
            DType::BF16 => Ok(sum.to_scalar::<half::bf16>()? != half::bf16::from_f32_const(0.)),
            DType::F32 => Ok(sum.to_scalar::<f32>()? != 0.),
            DType::F64 => Ok(sum.to_scalar::<f64>()? != 0.),
            DType::F8E4M3 => Ok(sum.to_scalar::<F8E4M3>()? != F8E4M3::ZERO),
        }
    }
1828}
1829
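/// Clamp a tensor to its dtype's maximum magnitude minus a safety margin (tightened further when
/// the tensor already contains infinities), intended to guard later half-precision computation
/// against overflow.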
1830pub fn clamp_for_f16(xs: &Tensor) -> Result<Tensor> {
1831 let mut max = match xs.dtype() {
1832 DType::U8 => u8::MAX as f32 - 1000.,
1833 DType::U32 => u32::MAX as f32 - 1000.,
1834 DType::I16 => i16::MAX as f32 - 1000.,
1835 DType::I32 => i32::MAX as f32 - 1000.,
1836 DType::I64 => i64::MAX as f32 - 1000.,
1837 DType::F16 => half::f16::MAX.to_f32_const() - 1000.,
1838 DType::BF16 => half::bf16::MAX.to_f32_const() - 1000.,
1839 DType::F32 => f32::MAX - 1000.,
1840 DType::F64 => f64::MAX as f32 - 1000.,
1841 DType::F8E4M3 => F8E4M3::MAX.to_f32() - 1000.,
1842 };
1843 if xs.is_inf()?.any()? {
1844 max -= 1000.;
1845 }
1846 xs.clamp(-max, max)
1847}
1848
pub struct FloatInfo {
    /// Smallest finite representable value.
    pub min: f64,
    /// Largest finite representable value.
    pub max: f64,
    /// Difference between 1.0 and the next representable value (machine epsilon).
    pub eps: f64,
    pub dtype: DType,
}
1858
1859pub trait GetFloatInfo {
1860 fn finfo(&self) -> Result<FloatInfo>;
1861}
1862
1863impl GetFloatInfo for DType {
1864 fn finfo(&self) -> Result<FloatInfo> {
1865 let finfo = match self {
1866 Self::BF16 => FloatInfo {
1867 min: bf16::MIN.to_f64(),
1868 max: bf16::MAX.to_f64(),
1869 eps: bf16::EPSILON.to_f64(),
1870 dtype: DType::BF16,
1871 },
1872 Self::F16 => FloatInfo {
1873 min: f16::MIN.to_f64(),
1874 max: f16::MAX.to_f64(),
1875 eps: f16::EPSILON.to_f64(),
1876 dtype: DType::F16,
1877 },
1878 Self::F32 => FloatInfo {
1879 min: f32::MIN as f64,
1880 max: f32::MAX as f64,
1881 eps: f32::EPSILON as f64,
1882 dtype: DType::F32,
1883 },
1884 Self::F64 => FloatInfo {
1885 min: f64::MIN,
1886 max: f64::MAX,
1887 eps: f64::EPSILON,
1888 dtype: DType::F64,
1889 },
1890 Self::F8E4M3 => FloatInfo {
1891 min: F8E4M3::MIN.to_f64(),
1892 max: F8E4M3::MAX.to_f64(),
1893 eps: F8E4M3::EPSILON.to_f64(),
1894 dtype: DType::F8E4M3,
1895 },
1896 other => {
1897 candle_core::bail!("Expected a float type for `GetFloatInfo`, got {other:?}");
1898 }
1899 };
1900 Ok(finfo)
1901 }
1902}
1903
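/// A gated feed-forward block: `down(act(gate(x)) * up(x))` (SwiGLU-style when `act` is SiLU).
///
/// `gate`/`up` are column-parallel and `down` is row-parallel so the MLP can be sharded across a
/// tensor-parallel `Comm`; when the quant method reports a quantized activation type, inputs are
/// cast to it for the matmuls and cast back afterwards.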
1904#[derive(Clone)]
1905pub struct Mlp {
1906 pub gate: Arc<dyn QuantMethod>,
1907 pub up: Arc<dyn QuantMethod>,
1908 pub down: Arc<dyn QuantMethod>,
1909 act: Activation,
1910 params: Vec<usize>,
1911}
1912
1913impl Mlp {
1914 pub fn new(
1915 vb: ShardedVarBuilder,
1916 hidden_size: usize,
1917 intermediate_size: usize,
1918 quantization_config: &Option<QuantizedConfig>,
1919 hidden_act: Activation,
1920 comm: &Arc<mistralrs_quant::Comm>,
1921 ) -> Result<Self> {
1922 Ok(Self {
1923 gate: ColumnParallelLayer::new(
1924 hidden_size,
1925 intermediate_size,
1926 quantization_config,
1927 false,
1928 comm,
1929 vb.pp("gate_proj"),
1930 )?,
1931 up: ColumnParallelLayer::new(
1932 hidden_size,
1933 intermediate_size,
1934 quantization_config,
1935 false,
1936 comm,
1937 vb.pp("up_proj"),
1938 )?,
1939 down: RowParallelLayer::new(
1940 intermediate_size,
1941 hidden_size,
1942 quantization_config,
1943 false,
1944 comm,
1945 vb.pp("down_proj"),
1946 )?,
1947 act: hidden_act,
1948 params: vec![hidden_size, intermediate_size],
1949 })
1950 }
1951
1952 pub fn replicate(
1953 params: &[usize],
1954 vb: ShardedVarBuilder,
1955 act: Activation,
1956 comm: &Arc<mistralrs_quant::Comm>,
1957 ) -> Result<Self> {
1958 Self::new(vb, params[0], params[1], &None, act, comm)
1959 }
1960
1961 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1962 let original_dtype = xs.dtype();
1963 let mut xs = xs.clone();
1964 if let Some(t) = self.gate.quantized_act_type() {
1965 xs = xs.to_dtype(t)?;
1966 }
1967 let lhs = self.gate.forward(&xs)?;
1968 let rhs = self.up.forward(&xs)?;
1969 let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
1970 &lhs,
1971 &rhs,
1972 self.act.try_into()?,
1973 )?)?;
1974 if self.gate.quantized_act_type().is_some() {
1975 res = res.to_dtype(original_dtype)?;
1976 }
1977 Ok(res)
1978 }
1979}
1980
1981impl AnyMoeTrainableLayer for Mlp {}
1982
1983impl MlpLayer for Mlp {
1984 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1985 let original_dtype = xs.dtype();
1986 let mut xs = xs.clone();
1987 if let Some(t) = self.gate.quantized_act_type() {
1988 xs = xs.to_dtype(t)?;
1989 }
1990 let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
1991 let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
1992 let mut res = if matches!(
1993 self.act,
1994 Activation::Gelu | Activation::Silu | Activation::Relu
1995 ) {
1996 MatMul.qmethod_matmul(
1997 &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
1998 &*self.down,
1999 )?
2000 } else {
2001 MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
2002 };
2003 if self.gate.quantized_act_type().is_some() {
2004 res = res.to_dtype(original_dtype)?;
2005 }
2006 Ok(res)
2007 }
2008 fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
2009 vec![&mut self.gate, &mut self.up, &mut self.down]
2010 }
2011 fn clone(&self) -> Box<dyn MlpLayer> {
2012 Box::new(Clone::clone(self))
2013 }
2014 fn get_params(&self) -> &[usize] {
2015 &self.params
2016 }
2017 fn hidden_act(&self) -> Activation {
2018 self.act
2019 }
2020 fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
2022 let gate = if let Some(ref delta) = deltas[0] {
2023 self.gate.add_delta_w(delta)?
2024 } else {
2025 self.gate.clone()
2026 };
2027 let up = if let Some(ref delta) = deltas[1] {
2028 self.up.add_delta_w(delta)?
2029 } else {
2030 self.up.clone()
2031 };
2032 let down = if let Some(ref delta) = deltas[2] {
2033 self.down.add_delta_w(delta)?
2034 } else {
2035 self.down.clone()
2036 };
2037
2038 Ok(Box::new(Self {
2039 gate,
2040 up,
2041 down,
2042 act: self.act,
2043 params: self.params.clone(),
2044 }))
2045 }
2046
2047 fn dtype_device(&self) -> (DType, Device) {
2048 self.gate.dtype_and_device()
2049 }
2050}
2051
2052pub struct AvgPool2d {
2053 kernel_size: usize,
2054 stride: usize,
2055}
2056
2057impl AvgPool2d {
2058 pub fn new(kernel_size: usize, stride: usize) -> Self {
2059 Self {
2060 kernel_size,
2061 stride,
2062 }
2063 }
2064
2065 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2066 xs.avg_pool2d_with_stride(self.kernel_size, self.stride)
2067 }
2068}
2069
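/// Reflection padding over the last two dimensions of an NCHW tensor, mirroring
/// `torch.nn.ReflectionPad2d`: `padding` is `(left, right, top, bottom)` and the reflected
/// window excludes the edge row/column itself, so each pad amount must be smaller than the
/// corresponding input dimension.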
2070pub struct ReflectionPad2d {
2077 padding: (usize, usize, usize, usize),
2078}
2079
2080impl ReflectionPad2d {
2081 pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2082 Self { padding }
2083 }
2084}
2085
2086impl Module for ReflectionPad2d {
2087 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2088 let (pad_left, pad_right, pad_top, pad_bottom) = self.padding;
2089
2090 let (_n, _c, h, w) = xs.dims4()?;
2091
        // Reflect across the left edge: columns 1..=pad_left, mirrored, so the edge column itself
        // is not repeated.
        let left_pad = if pad_left > 0 {
            let indices: Vec<i64> = (1..=pad_left as i64).rev().collect();
            Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2098 } else {
2099 None
2100 };
2101
2102 let right_pad = if pad_right > 0 {
2104 let start = w as i64 - 2;
2106 let indices: Vec<i64> = (0..pad_right as i64).map(|i| start - i).collect();
2107 Some(xs.index_select(&Tensor::new(indices, &Device::Cpu)?, 3)?)
2108 } else {
2109 None
2110 };
2111
2112 let x_padded_width = match (left_pad, right_pad) {
2114 (Some(l), Some(r)) => Tensor::cat(&[l, xs.clone(), r], 3)?,
2115 (Some(l), None) => Tensor::cat(&[l, xs.clone()], 3)?,
2116 (None, Some(r)) => Tensor::cat(&[xs.clone(), r], 3)?,
2117 (None, None) => xs.clone(),
2118 };
2119
2120 let top_pad = if pad_top > 0 {
2123 let indices: Vec<i64> = (1..=pad_top as i64).rev().collect();
2124 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2125 } else {
2126 None
2127 };
2128
2129 let bottom_pad = if pad_bottom > 0 {
2131 let start = h as i64 - 2;
2132 let indices: Vec<i64> = (0..pad_bottom as i64).map(|i| start - i).collect();
2133 Some(x_padded_width.index_select(&Tensor::new(indices, &Device::Cpu)?, 2)?)
2134 } else {
2135 None
2136 };
2137
2138 let x_padded = match (top_pad, bottom_pad) {
2140 (Some(t), Some(b)) => Tensor::cat(&[t, x_padded_width, b], 2)?,
2141 (Some(t), None) => Tensor::cat(&[t, x_padded_width], 2)?,
2142 (None, Some(b)) => Tensor::cat(&[x_padded_width, b], 2)?,
2143 (None, None) => x_padded_width,
2144 };
2145
2146 Ok(x_padded)
2147 }
2148}
2149
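/// An [`Embedding`] whose output is multiplied by a constant `scale`
/// (e.g. `sqrt(hidden_size)` in Gemma-style models).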
2150pub struct ScaledEmbedding {
2151 scale: f64,
2152 embedding: Embedding,
2153}
2154
2155impl ScaledEmbedding {
2156 pub fn new(scale: f64, embedding: Embedding) -> Self {
2157 Self { scale, embedding }
2158 }
2159
2160 pub fn embeddings(&self) -> &Tensor {
2161 self.embedding.embeddings()
2162 }
2163}
2164
2165impl Module for ScaledEmbedding {
2166 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
2167 xs.apply(&self.embedding)? * self.scale
2168 }
2169}