use std::{collections::HashMap, sync::Arc};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear};
use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
layers::{Activation, F32RmsNorm, Qwen2VLRotaryEmbedding, Sdpa},
paged_attention::{AttentionImplementation, ModelConfigMetadata},
pipeline::{
extract_logits, text_models_inputs_processor::FlashParams, EitherCache, IsqModel, KvCache,
NormalCache, NormalLoadingMetadata,
},
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
use super::config::Config;
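
/// Gated feed-forward block used in every decoder layer:
/// `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, with each projection
/// behind a `QuantMethod` so it can run quantized or unquantized.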
struct Mlp {
gate_proj: Arc<dyn QuantMethod>,
up_proj: Arc<dyn QuantMethod>,
down_proj: Arc<dyn QuantMethod>,
act_fn: Activation,
}

impl Mlp {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = mistralrs_quant::linear_no_bias(
hidden_sz,
intermediate_sz,
&cfg.quantization_config,
vb.pp("gate_proj"),
)?;
let up_proj = mistralrs_quant::linear_no_bias(
hidden_sz,
intermediate_sz,
&cfg.quantization_config,
vb.pp("up_proj"),
)?;
let down_proj = mistralrs_quant::linear_no_bias(
intermediate_sz,
hidden_sz,
&cfg.quantization_config,
vb.pp("down_proj"),
)?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
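
/// Casts the input to the quantized activation dtype when the projections
/// are quantized, and casts the result back to the caller's dtype.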
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let original_dtype = xs.dtype();
let mut xs = xs.clone();
if let Some(t) = self.gate_proj.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
let lhs = self.gate_proj.forward(&xs)?.apply(&self.act_fn)?;
let rhs = self.up_proj.forward(&xs)?;
self.down_proj
.forward(&(lhs * rhs)?)?
.to_dtype(original_dtype)
}
}
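
/// Multi-head self-attention with grouped-query KV heads and multimodal
/// rotary embeddings (M-RoPE). Per Qwen2, the q/k/v projections carry biases
/// while o_proj does not; all four are `QuantMethod`s.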
struct Attention {
q_proj: Arc<dyn QuantMethod>,
k_proj: Arc<dyn QuantMethod>,
v_proj: Arc<dyn QuantMethod>,
o_proj: Arc<dyn QuantMethod>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
sdpa_params: SdpaParams,
}

impl Attention {
fn new(rotary_emb: Arc<Qwen2VLRotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let head_dim = hidden_sz / num_heads;
let q_proj = mistralrs_quant::linear(
hidden_sz,
num_heads * head_dim,
&cfg.quantization_config,
vb.pp("q_proj"),
)?;
let k_proj = mistralrs_quant::linear(
hidden_sz,
num_kv_heads * head_dim,
&cfg.quantization_config,
vb.pp("k_proj"),
)?;
let v_proj = mistralrs_quant::linear(
hidden_sz,
num_kv_heads * head_dim,
&cfg.quantization_config,
vb.pp("v_proj"),
)?;
let o_proj = mistralrs_quant::linear_no_bias(
num_heads * head_dim,
hidden_sz,
&cfg.quantization_config,
vb.pp("o_proj"),
)?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
head_dim,
rotary_emb,
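// Grouped-query attention: each KV head serves num_heads / num_kv_heads
// query heads, with standard 1/sqrt(head_dim) scaling and no sliding
// window or logit softcapping.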
sdpa_params: SdpaParams {
n_kv_groups: num_heads / num_kv_heads,
use_flash_attn: false,
softcap: None,
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window: None,
},
})
}

#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
cos_sin: &(Tensor, Tensor),
kv_cache: &mut KvCache,
flash_params: &FlashParams,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let original_dtype = xs.dtype();
let mut xs = xs.clone();
if let Some(t) = self.q_proj.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
let mut q = self.q_proj.forward(&xs)?;
let mut k = self.k_proj.forward(&xs)?;
let mut v = self.v_proj.forward(&xs)?;
if self.q_proj.quantized_act_type().is_some() {
q = q.to_dtype(original_dtype)?;
k = k.to_dtype(original_dtype)?;
v = v.to_dtype(original_dtype)?;
}
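// Prefill (q_len > 1): reshape then transpose to (b, heads, seq, head_dim).
// Single-token decode: reshape directly, no transpose needed.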
let (mut q, mut k, v) = if q_len != 1 {
let q = q
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
(q, k, v)
} else {
let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
(q, k, v)
};
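// Apply M-RoPE in place to q and k, then append to the KV cache and run
// SDPA in f32 for stability before casting back to the activation dtype.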
self.rotary_emb.forward(cos_sin, &mut q, &mut k)?;
let mut attn_output = {
let (k, v) = kv_cache.append(&k, &v)?;
Sdpa.run_attention(
&q.contiguous()?.to_dtype(DType::F32)?,
&k.contiguous()?.to_dtype(DType::F32)?,
&v.contiguous()?.to_dtype(DType::F32)?,
attention_mask
.map(|mask| mask.to_dtype(DType::F32).unwrap())
.as_ref(),
Some(flash_params),
&self.sdpa_params,
)?
.to_dtype(q.dtype())?
};
if let Some(t) = self.q_proj.quantized_act_type() {
attn_output = attn_output.to_dtype(t)?;
}
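// Merge the heads back into (b, q_len, hidden_size); the masked (prefill)
// output is (b, heads, q_len, head_dim) and needs a transpose first.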
attn_output = if attention_mask.is_some() {
attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
} else {
attn_output.reshape((b_sz, q_len, ()))?
};
let mut res = self.o_proj.forward(&attn_output)?;
if self.q_proj.quantized_act_type().is_some() {
res = res.to_dtype(original_dtype)?;
}
Ok(res)
}
}
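
/// One pre-norm transformer block: RmsNorm -> self-attention -> residual add,
/// then RmsNorm -> gated MLP -> residual add.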
pub struct DecoderLayer {
self_attn: Attention,
mlp: Mlp,
input_layernorm: F32RmsNorm,
post_attention_layernorm: F32RmsNorm,
}

impl DecoderLayer {
fn new(
rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
cfg: &Config,
vb: VarBuilder,
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
) -> Result<Self> {
let self_attn = Attention::new(
rotary_emb,
cfg,
mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
)?;
let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
let input_layernorm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
)?;
let post_attention_layernorm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}

#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
cos_sin: &(Tensor, Tensor),
kv_cache: &mut KvCache,
flash_params: &FlashParams,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self
.self_attn
.forward(&xs, attention_mask, cos_sin, kv_cache, flash_params)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.mlp
.forward(&xs.apply(&self.post_attention_layernorm)?)?;
residual + xs
}
}
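
/// The text decoder of Qwen2-VL: token embeddings, device-mapped decoder
/// layers, a final RmsNorm, and an LM head that is either its own projection
/// or tied to the embedding matrix.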
pub struct Qwen2VLTextModel {
embed_tokens: Embedding,
pub(super) norm: F32RmsNorm,
layers: Vec<DecoderLayer>,
mapper: Box<dyn DeviceMapper + Send + Sync>,
lm_head: Arc<dyn QuantMethod>,
pub(super) cache: EitherCache,
pub(super) cfg: ModelConfigMetadata,
pub(super) device: Device,
pub(super) dtype: DType,
pub(super) max_seq_len: usize,
}

impl Qwen2VLTextModel {
pub fn new(
cfg: &Config,
vb: VarBuilder,
_is_gptx: bool,
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
if !matches!(attention_mechanism, AttentionImplementation::Eager) {
candle_core::bail!("Expected eager attention implementation");
}
let mapper = normal_loading_metadata.mapper;
let vb_m = vb.pp("model");
let embed_tokens = candle_nn::embedding(
cfg.vocab_size,
cfg.hidden_size,
mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
)?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
let mut ropes = HashMap::new();
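// Build one rotary embedding per device location so that layers mapped to
// different devices each use a copy resident on their own device.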
for layer_idx in 0..cfg.num_hidden_layers {
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
ropes.insert(
device.location(),
Arc::new(Qwen2VLRotaryEmbedding::new(
cfg.rope_theta as f32,
head_dim,
device,
cfg.rope_scaling.mrope_section.clone(),
)?),
);
}
let vb_l = vb_m.pp("layers");
for layer_idx in
NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers")
{
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
let rotary_emb = ropes
.get(&device.location())
.expect("No RoPE for device location!")
.clone();
let layer = DecoderLayer::new(
rotary_emb.clone(),
cfg,
vb_l.pp(layer_idx),
&*mapper,
layer_idx,
normal_loading_metadata.loading_isq,
)?;
layers.push(layer);
}
let norm = F32RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_nm_device(vb_m.pp("norm"), false),
)?;
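// The LM head reuses the embedding matrix when `tie_word_embeddings` is set;
// otherwise it is loaded as its own projection.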
let lm_head = if !cfg.tie_word_embeddings {
mistralrs_quant::linear_no_bias(
cfg.hidden_size,
cfg.vocab_size,
&None,
mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
)?
} else {
Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
candle_nn::Linear::new(
mapper.cast_nm_device(
embed_tokens.embeddings(),
normal_loading_metadata.loading_isq,
)?,
None,
),
))?)
};
Ok(Self {
embed_tokens,
norm,
layers,
lm_head,
cache: EitherCache::Normal(NormalCache::new(
cfg.num_hidden_layers,
cfg.max_position_embeddings,
)),
max_seq_len: cfg.max_position_embeddings,
mapper,
cfg: ModelConfigMetadata {
num_layers: cfg.num_hidden_layers,
hidden_size: cfg.hidden_size,
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: cfg.sliding_window,
k_head_dim: None,
v_head_dim: None,
},
device: normal_loading_metadata.real_device.clone(),
dtype: vb.dtype(),
})
}

pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
self.embed_tokens.forward(input_ids)
}
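
/// Run the decoder over precomputed input embeddings and return the logits
/// selected by `context_lens`.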
pub fn forward_embeds(
&self,
mut xs: Tensor,
attention_mask: Option<&Tensor>,
position_ids: &Tensor,
context_lens: Vec<(usize, usize)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let cache = &mut self.cache.normal().0;
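// The rotary cos/sin depend only on the position ids, so compute them once
// and reuse them for every layer.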
let cos_sin = self.layers[0]
.self_attn
.rotary_emb
.compute_cos_sin(position_ids, xs.dtype())?;
for (i, layer) in self.layers.iter().enumerate() {
xs = self.mapper.map(xs, i)?;
xs = layer.forward(
&xs,
attention_mask
.as_ref()
.map(|m| m.to_device(xs.device()).unwrap())
.as_ref(),
&cos_sin,
&mut cache[i],
flash_params,
)?
}
let xs = xs.to_device(&self.device)?;
let mut xs = xs.apply(&self.norm)?;
if let Some(t) = self.lm_head.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
extract_logits(&self.lm_head.forward(&xs)?, context_lens)
}
}
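
// In-situ quantization (ISQ) support: expose the quantizable linear layers
// and the tensors that are kept unquantized.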
impl IsqModel for Qwen2VLTextModel {
fn get_layers(
&mut self,
) -> (
Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
&dyn DeviceMapper,
) {
let mut tensors = Vec::new();
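// `lm_head` carries no layer index; per-layer projections are tagged with
// their layer index for device mapping.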
tensors.push((&mut self.lm_head, None));
for (i, layer) in self.layers.iter_mut().enumerate() {
tensors.push((&mut layer.self_attn.q_proj, Some(i)));
tensors.push((&mut layer.self_attn.k_proj, Some(i)));
tensors.push((&mut layer.self_attn.v_proj, Some(i)));
tensors.push((&mut layer.self_attn.o_proj, Some(i)));
tensors.push((&mut layer.mlp.gate_proj, Some(i)));
tensors.push((&mut layer.mlp.up_proj, Some(i)));
tensors.push((&mut layer.mlp.down_proj, Some(i)));
}
(tensors, &*self.mapper)
}
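
/// Embeddings and RmsNorm weights are serialized as-is; the quantizable
/// projections are handled through `get_layers`.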
fn residual_tensors(&self) -> Vec<(String, Tensor)> {
let uvb = UnVarBuilder::new();
let uvb_m = uvb.pp("model");
uvb_m.pp("embed_tokens").add(&self.embed_tokens);
uvb_m.pp("norm").add(&self.norm);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let uvb_l = uvb_m.pp("layers").pp(layer_idx);
uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
uvb_l
.pp("post_attention_layernorm")
.add(&layer.post_attention_layernorm);
}
uvb.to_safetensors()
}
}