#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::{
f32::consts::PI,
ops::Mul,
str::FromStr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use candle_core::{
quantized::{QMatMul, QTensor},
Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{Conv2d, Conv2dConfig, Linear, Module, VarBuilder};
use mistralrs_quant::QuantMethod;
use serde::{Deserialize, Serialize};
pub use crate::attention::Sdpa;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::repeat_kv;
use crate::{
gguf::Content,
models::llama,
ops::SplitOp,
vision_models::mllama::{MLlamaRopeScaling, MLlamaRopeType, MLlamaTextConfig},
INHIBIT_GEMM_F16,
};
#[derive(Debug, Clone)]
pub struct RmsNorm {
eps: f64,
weight: Tensor,
}
impl RmsNorm {
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?;
let w = inner.inner().weight().clone();
Ok(Self { eps, weight: w })
}
pub fn new_gemma(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?;
let w = (inner.inner().weight().clone() + 1.0)?;
Ok(Self { eps, weight: w })
}
pub fn undo_gemma(&self) -> Result<Self> {
Ok(Self {
eps: self.eps,
weight: (&self.weight - 1.0)?,
})
}
pub fn from_w(w: Tensor, eps: f64) -> Result<Self> {
Ok(Self { eps, weight: w })
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
}
}
#[derive(Debug, Clone)]
pub struct F32RmsNorm {
w: Tensor,
eps: f64,
}
impl F32RmsNorm {
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
Ok(Self {
w: vb.get((size,), "weight")?,
eps,
})
}
pub fn weight(&self) -> &Tensor {
&self.w
}
}
impl Module for F32RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let initial_type = xs.dtype();
let mut xs = xs.to_dtype(DType::F32)?;
let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
}
}
#[derive(Debug, Clone)]
pub struct QRmsNorm {
eps: f64,
weight: Tensor,
}
impl QRmsNorm {
pub fn new(scale: QTensor, eps: f32) -> Result<Self> {
let scale = scale.dequantize(&scale.device())?;
Ok(Self {
eps: eps as f64,
weight: scale,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32)
}
}
#[derive(Debug, Clone)]
pub struct PhiRotaryEmbedding {
short_sin: Tensor,
short_cos: Tensor,
long_cos: Option<Tensor>,
long_sin: Option<Tensor>,
original_max_position_embeddings: usize,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ScaledRopeType {
#[serde(alias = "su")]
#[serde(alias = "longrope")]
Su,
#[serde(alias = "yarn")]
Yarn,
}
impl FromStr for ScaledRopeType {
type Err = candle_core::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"su" | "longrope" => Ok(Self::Su),
"yarn" => Ok(Self::Yarn),
_ => Err(candle_core::Error::Msg(
"Expected either `su` or `yarn` scaled RoPE type.".to_string(),
)),
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum PhiRopeScalingConfig {
Classic {
short_factor: Vec<f64>,
long_factor: Vec<f64>,
#[serde(rename = "type")]
scaling_type: ScaledRopeType,
},
Scaled {
short_factor: Vec<f64>,
long_factor: Vec<f64>,
#[serde(rename = "type")]
scaling_type: ScaledRopeType,
long_mscale: f64,
short_mscale: f64,
},
}
pub struct PhiRopeConfig {
pub rope_scaling: Option<PhiRopeScalingConfig>,
pub max_position_embeddings: usize,
pub original_max_position_embeddings: usize,
pub rope_theta: f64,
pub head_dim: usize,
}
impl PhiRotaryEmbedding {
fn new_classic_scaled(
short_factor: &[f64],
long_factor: &[f64],
scaling_type: &ScaledRopeType,
cfg: &PhiRopeConfig,
dtype: DType,
dev: &Device,
) -> Result<Self> {
let max_seq_len = cfg.max_position_embeddings;
let dim = cfg.head_dim;
let scale =
cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64;
let scaling_factor = if scale <= 1.0 {
1.0
} else {
match scaling_type {
ScaledRopeType::Su => {
(1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt()
}
ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0,
}
};
let inv_freq_long = (0..dim)
.step_by(2)
.enumerate()
.map(|(k, i)| {
(1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
})
.collect::<Vec<_>>();
let inv_freq_short = (0..dim)
.step_by(2)
.enumerate()
.map(|(k, i)| {
(1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32
})
.collect::<Vec<_>>();
let inv_freq_len = inv_freq_long.len();
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
let freqs_long = t.matmul(&inv_freq_long)?;
let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
let inv_freq_short =
Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let freqs_short = t.matmul(&inv_freq_short)?;
let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;
Ok(Self {
short_cos,
short_sin,
long_cos: Some(long_cos),
long_sin: Some(long_sin),
original_max_position_embeddings: cfg.original_max_position_embeddings,
})
}
fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
let max_seq_len = cfg.max_position_embeddings;
let dim = cfg.head_dim;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
let cos = freqs.cos()?.to_dtype(dtype)?;
Ok(Self {
short_cos: cos,
short_sin: sin,
long_cos: None,
long_sin: None,
original_max_position_embeddings: cfg.original_max_position_embeddings,
})
}
#[allow(clippy::too_many_arguments)]
fn new_scaled(
short_factor: &[f64],
long_factor: &[f64],
scaling_type: &ScaledRopeType,
long_mscale: f64,
short_mscale: f64,
cfg: &PhiRopeConfig,
dtype: DType,
dev: &Device,
) -> Result<Self> {
let max_seq_len = cfg.max_position_embeddings;
let dim = cfg.head_dim;
if !matches!(scaling_type, ScaledRopeType::Su) {
candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
}
if short_factor.len() != dim / 2 {
candle_core::bail!(
"Misaligned length {}, expected {} for `su`/`longrope` short rescale factors",
short_factor.len(),
dim / 2
);
}
if long_factor.len() != dim / 2 {
candle_core::bail!(
"Misaligned length {}, expected {} for `su`/`longrope` long rescale factors",
long_factor.len(),
dim / 2
);
}
let inv_freq_short: Vec<_> = (0..dim)
.step_by(2)
.enumerate()
.map(|(k, i)| {
1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
})
.collect();
let inv_freq_len_short = inv_freq_short.len();
let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?;
let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs_short = t_short.matmul(&inv_freq_short)?;
let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?;
let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?;
let inv_freq_long: Vec<_> = (0..dim)
.step_by(2)
.enumerate()
.map(|(k, i)| {
1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32
})
.collect();
let inv_freq_len_long = inv_freq_long.len();
let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?;
let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs_long = t_long.matmul(&inv_freq_long)?;
let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?;
let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?;
Ok(Self {
short_cos: cos_short,
short_sin: sin_short,
long_cos: Some(cos_long),
long_sin: Some(sin_long),
original_max_position_embeddings: cfg.original_max_position_embeddings,
})
}
pub fn new(dtype: DType, cfg: impl Into<PhiRopeConfig>, dev: &Device) -> Result<Self> {
let cfg: PhiRopeConfig = cfg.into();
match &cfg.rope_scaling {
Some(PhiRopeScalingConfig::Classic {
short_factor,
long_factor,
scaling_type,
}) => {
Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev)
}
Some(PhiRopeScalingConfig::Scaled {
short_factor,
long_factor,
scaling_type,
long_mscale,
short_mscale,
}) => Self::new_scaled(
short_factor,
long_factor,
scaling_type,
*long_mscale,
*short_mscale,
&cfg,
dtype,
dev,
),
None => Self::new_unscaled(&cfg, dtype, dev),
}
}
fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) {
if self.long_cos.is_none() {
return (&self.short_sin, &self.short_cos);
}
let seq_len = position_ids.iter().max().unwrap() + 1;
if seq_len > self.original_max_position_embeddings {
(
self.long_sin.as_ref().unwrap(),
self.long_cos.as_ref().unwrap(),
)
} else {
(&self.short_sin, &self.short_cos)
}
}
pub fn forward(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offsets: &[usize],
position_ids: &[usize],
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let mut q_embeds = Vec::new();
let mut k_embeds = Vec::new();
let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
for (i, offset) in seqlen_offsets.iter().enumerate() {
let cos = cos.narrow(0, *offset, seq_len)?;
let sin = sin.narrow(0, *offset, seq_len)?;
let q_embed =
candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
let k_embed =
candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
q_embeds.push(q_embed);
k_embeds.push(k_embed);
}
Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?))
}
}
#[derive(Debug, Clone)]
pub enum Llama3RotaryEmbedding {
Llama3 {
sin: Tensor,
cos: Tensor,
is_gptx: bool,
},
Default(RotaryEmbedding),
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum Llama3RopeType {
#[serde(rename = "llama3")]
Llama3,
#[default]
#[serde(rename = "default")]
Default,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Llama3RopeConfig {
pub factor: f32,
pub low_freq_factor: f32,
pub high_freq_factor: f32,
pub original_max_position_embeddings: usize,
pub rope_type: Llama3RopeType,
}
fn calculate_default_inv_freq(cfg: &llama::Config) -> Vec<f32> {
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
(0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.collect()
}
impl Llama3RotaryEmbedding {
pub fn new_llama3(
dtype: DType,
cfg: &llama::Config,
dev: &Device,
is_gpt_neox: bool,
) -> Result<Self> {
match &cfg.rope_scaling {
None
| Some(Llama3RopeConfig {
rope_type: Llama3RopeType::Default,
..
}) => Ok(Self::Default(RotaryEmbedding::new(
cfg.rope_theta,
cfg.hidden_size / cfg.num_attention_heads,
cfg.max_position_embeddings,
dev,
is_gpt_neox,
dtype,
)?)),
Some(rope_scaling) => {
let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.low_freq_factor;
let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.high_freq_factor;
let inv_freq = calculate_default_inv_freq(cfg)
.into_iter()
.map(|freq| {
let wavelen = 2. * PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / rope_scaling.factor
} else {
let smooth = (rope_scaling.original_max_position_embeddings as f32
/ wavelen
- rope_scaling.low_freq_factor)
/ (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
(1. - smooth) * freq / rope_scaling.factor + smooth * freq
}
})
.collect::<Vec<_>>();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
let cos = freqs.cos()?.to_dtype(dtype)?;
Ok(Self::Llama3 {
sin,
cos,
is_gptx: is_gpt_neox,
})
}
}
}
pub fn new_mllama3(
dtype: DType,
cfg: &MLlamaTextConfig,
dev: &Device,
is_gpt_neox: bool,
) -> Result<Self> {
match &cfg.rope_scaling {
None
| Some(MLlamaRopeScaling {
rope_type: MLlamaRopeType::Default,
..
}) => Ok(Self::Default(RotaryEmbedding::new(
cfg.rope_theta,
cfg.hidden_size / cfg.num_attention_heads,
cfg.max_position_embeddings,
dev,
is_gpt_neox,
dtype,
)?)),
Some(MLlamaRopeScaling {
rope_type: MLlamaRopeType::Llama3,
original_max_position_embeddings,
factor,
attention_factor: _,
beta_fast: _,
beta_slow: _,
short_factor: _,
long_factor: _,
low_freq_factor,
high_freq_factor,
}) => {
let factor = factor.context("MLlama Llama3 RoPE needs `factor` parameter.")?;
let low_freq_factor = low_freq_factor
.context("MLlama Llama3 RoPE needs `low_freq_factor` parameter.")?;
let high_freq_factor = high_freq_factor
.context("MLlama Llama3 RoPE needs `high_freq_factor` parameter.")?;
let low_freq_wavelen = *original_max_position_embeddings as f32 / low_freq_factor;
let high_freq_wavelen = *original_max_position_embeddings as f32 / high_freq_factor;
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
let inv_freq = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.map(|freq| {
let wavelen = 2. * PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / factor
} else {
let smooth = (*original_max_position_embeddings as f32 / wavelen
- low_freq_factor)
/ (high_freq_factor - low_freq_factor);
(1. - smooth) * freq / factor + smooth * freq
}
})
.collect::<Vec<_>>();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
let cos = freqs.cos()?.to_dtype(dtype)?;
Ok(Self::Llama3 {
sin,
cos,
is_gptx: is_gpt_neox,
})
}
Some(MLlamaRopeScaling {
rope_type: other, ..
}) => {
candle_core::bail!(
"MLlama doesn't support any other RoPE type than `llama3`, got {other:?}"
)
}
}
}
pub fn forward(
&self,
positions: &[usize],
positions_kernel: &Tensor,
q: &mut Tensor,
k: &mut Tensor,
b_sz: usize,
) -> Result<()> {
match self {
Self::Llama3 { sin, cos, is_gptx } => {
let (b_sz_seq_len, h, n_embd) = q.dims3()?;
*q = q
.reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))?
.transpose(1, 2)?;
let (b_sz_seq_len, h, n_embd) = k.dims3()?;
*k = k
.reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))?
.transpose(1, 2)?;
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let mut q_embeds = Vec::new();
let mut k_embeds = Vec::new();
for (i, offset) in positions.iter().enumerate() {
let cos = cos.narrow(0, *offset, seq_len)?;
let sin = sin.narrow(0, *offset, seq_len)?;
let rope = if *is_gptx {
candle_nn::rotary_emb::rope
} else {
candle_nn::rotary_emb::rope_i
};
let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?;
q_embeds.push(q_embed);
k_embeds.push(k_embed);
}
*q = Tensor::cat(&q_embeds, 0)?;
*k = Tensor::cat(&k_embeds, 0)?;
Ok(())
}
Self::Default(rope) => rope.forward(positions, positions_kernel, q, k, b_sz),
}
}
}
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
inv_freq: Tensor,
mrope_section: Vec<usize>,
}
impl Qwen2VLRotaryEmbedding {
pub fn new(
base: f32,
head_dim: usize,
device: &Device,
mrope_section: Vec<usize>,
) -> Result<Self> {
let inv_freq: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / base.powf(i as f32 / head_dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?.to_dtype(DType::F32)?;
Ok(Self {
inv_freq,
mrope_section,
})
}
pub fn compute_cos_sin(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
let inv_freq_expanded =
self.inv_freq
.reshape((1, 1, (), 1))?
.repeat((3, position_ids.dim(1)?, 1, 1))?;
let position_ids_expanded = position_ids.unsqueeze(2)?;
let freqs = inv_freq_expanded
.matmul(&position_ids_expanded.to_dtype(inv_freq_expanded.dtype())?)?
.transpose(2, 3)?;
let cos = freqs.cos()?;
let sin = freqs.sin()?;
let cos = Tensor::cat(
&cos.split(&self.mrope_section, D::Minus1)?
.into_iter()
.enumerate()
.map(|(i, m)| m.i(i % 3))
.collect::<Result<Vec<_>>>()?,
D::Minus1,
)?
.squeeze(0)?
.to_dtype(dtype)?
.contiguous()?;
let sin = Tensor::cat(
&sin.split(&self.mrope_section, D::Minus1)?
.into_iter()
.enumerate()
.map(|(i, m)| m.i(i % 3))
.collect::<Result<Vec<_>>>()?,
D::Minus1,
)?
.squeeze(0)?
.to_dtype(dtype)?
.contiguous()?;
Ok((cos, sin))
}
pub fn forward(
&self,
(cos, sin): &(Tensor, Tensor),
q: &mut Tensor,
k: &mut Tensor,
) -> Result<()> {
*q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?;
*k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?;
Ok(())
}
}
pub struct MatMul;
pub(crate) static USE_MATMUL_VIA_F16: AtomicBool = AtomicBool::new(false);
pub(crate) fn set_use_matmul_via_f16(via_f16: bool) {
if !INHIBIT_GEMM_F16.load(Ordering::Relaxed) {
USE_MATMUL_VIA_F16.store(via_f16, Ordering::Relaxed)
}
}
pub fn get_use_matmul_via_f16() -> bool {
USE_MATMUL_VIA_F16.load(Ordering::Relaxed)
}
impl MatMul {
pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
if !get_use_matmul_via_f16() {
return a.matmul(b);
}
let original_dtype = a.dtype();
a.to_dtype(DType::F16)?
.matmul(&b.to_dtype(DType::F16)?)?
.to_dtype(original_dtype)
}
pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
self.matmul(a, b)? / scale
}
pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
if get_use_matmul_via_f16() {
matmul.forward_via_f16(x)
} else {
matmul.forward(x)
}
}
pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
if get_use_matmul_via_f16() {
matmul.forward_via_half(x)
} else {
matmul.forward(x)
}
}
}
#[derive(Debug, Clone)]
pub struct QLinear {
inner: QMatMul,
bias: Option<Tensor>,
dtype: DType,
}
impl QLinear {
pub fn new<R: std::io::Read + std::io::Seek>(
ct: &mut Content<'_, R>,
name: &str,
device: &Device,
) -> Result<Self> {
let w = ct.tensor(&format!("{name}.weight"), device)?;
let b = ct.tensor(&format!("{name}.bias"), device)?;
let inner = QMatMul::from_qtensor(w)?;
let bias = b.dequantize(device)?;
Ok(Self {
inner,
bias: Some(bias),
dtype: DType::F32,
})
}
pub fn from_linear(linear: Linear) -> Self {
Self {
inner: QMatMul::Tensor(linear.weight().clone()),
bias: linear.bias().cloned(),
dtype: linear.weight().dtype(),
}
}
pub fn from_parts(w: Tensor, b: Option<Tensor>) -> Self {
let dtype = w.dtype();
Self {
inner: QMatMul::Tensor(w),
bias: b,
dtype,
}
}
pub fn from_qparts(w: QTensor, b: Option<Tensor>) -> Self {
if let Some(ref b) = b {
assert_eq!(b.dtype(), DType::F32);
}
Self {
inner: QMatMul::QTensor(Arc::new(w)),
bias: b,
dtype: DType::F32,
}
}
pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self {
Self {
inner,
bias: old.bias.clone(),
dtype: old.dtype,
}
}
pub fn inner(&mut self) -> &mut QMatMul {
&mut self.inner
}
pub fn inner_ref(&self) -> &QMatMul {
&self.inner
}
pub fn is_quant(&self) -> bool {
matches!(self.inner, QMatMul::QTensor(_))
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
self.bias.as_mut()
}
}
impl Module for QLinear {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = if self.is_quant() {
xs.to_dtype(DType::F32)?
} else {
xs.clone()
};
let forward_fn = if !get_use_matmul_via_f16() {
QMatMul::forward
} else {
QMatMul::forward_via_f16
};
if let Some(bias) = &self.bias {
forward_fn(&self.inner, &xs)?
.broadcast_add(bias)?
.to_dtype(self.dtype)
} else {
forward_fn(&self.inner, &xs)?.to_dtype(self.dtype)
}
}
}
#[derive(Debug, Clone)]
pub struct RotaryEmbedding(candle_nn::RotaryEmbedding);
impl RotaryEmbedding {
pub fn new(
base: f32,
head_dim: usize,
max_position_embeddings: usize,
device: &Device,
is_gpt_neox: bool,
dtype: DType,
) -> Result<Self> {
Ok(Self(candle_nn::RotaryEmbedding::new(
base,
head_dim,
max_position_embeddings,
device,
is_gpt_neox,
dtype,
)?))
}
pub fn new_partial(
base: f32,
head_dim: usize,
rot_dim: usize,
max_position_embeddings: usize,
device: &Device,
is_gpt_neox: bool,
dtype: DType,
) -> Result<Self> {
Ok(Self(candle_nn::RotaryEmbedding::new_partial(
base,
head_dim,
rot_dim,
max_position_embeddings,
device,
is_gpt_neox,
dtype,
)?))
}
pub fn forward(
&self,
positions: &[usize],
positions_kernel: &Tensor,
q: &mut Tensor,
k: &mut Tensor,
b_sz: usize,
) -> Result<()> {
self.0.forward(positions, positions_kernel, q, k, b_sz)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Activation {
#[default]
#[serde(alias = "gelu")]
Gelu,
#[serde(alias = "gelu_new")]
NewGelu,
Relu,
Relu2,
Relu6,
Silu,
Sigmoid,
HardSigmoid,
Swiglu,
Swish,
HardSwish,
Elu(f64),
LeakyRelu(f64),
#[serde(alias = "gelu_pytorch_tanh")]
GeluPytorchTanh,
QuickGelu,
}
impl Module for Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Gelu => xs.gelu_erf(),
Self::NewGelu => xs.gelu(),
Self::Relu => xs.relu(),
Self::Relu2 => xs.relu()?.sqr(),
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => xs.silu(),
Self::Sigmoid => candle_nn::ops::sigmoid(xs),
Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs),
Self::Swiglu => candle_nn::ops::swiglu(xs),
Self::Swish => xs * candle_nn::ops::sigmoid(xs)?,
Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope),
Self::GeluPytorchTanh => xs.gelu(),
Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv3dConfig {
pub padding: usize,
pub stride: usize,
pub dilation: usize,
pub groups: usize,
}
impl Default for Conv3dConfig {
fn default() -> Self {
Self {
padding: 0,
stride: 1,
dilation: 1,
groups: 1,
}
}
}
pub struct Conv3dNoBias {
conv2d_1: Conv2d,
conv2d_2: Conv2d,
}
impl Conv3dNoBias {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_sizes: [usize; 3],
cfg: Conv3dConfig,
vb: VarBuilder,
) -> Result<Self> {
let ws = vb.get(
(
out_channels,
in_channels / cfg.groups,
kernel_sizes[0],
kernel_sizes[1],
kernel_sizes[2],
),
"weight",
)?;
let w1 = ws.i((.., .., 0, .., ..))?;
let w2 = ws.i((.., .., 1, .., ..))?;
let cfg = Conv2dConfig {
padding: cfg.padding,
stride: cfg.stride,
dilation: cfg.dilation,
groups: cfg.groups,
};
Ok(Self {
conv2d_1: Conv2d::new(w1.contiguous()?, None, cfg),
conv2d_2: Conv2d::new(w2.contiguous()?, None, cfg),
})
}
}
impl Module for Conv3dNoBias {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs1 = xs.i((.., .., 0, .., ..))?;
let xs2 = xs.i((.., .., 1, .., ..))?;
(self.conv2d_1.forward(&xs1)? + self.conv2d_2.forward(&xs2)?)?.unsqueeze(2)
}
}