use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use super::config::Config;
use crate::{
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, F32RmsNorm, Qwen2_5VLRotaryEmbedding, Sdpa},
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        extract_logits, text_models_inputs_processor::FlashParams, EitherCache, IsqModel, KvCache,
        NormalCache, NormalLoadingMetadata,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

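/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
/// The gate/up projections are column-parallel and the down projection is
/// row-parallel so the MLP can be sharded across tensor-parallel ranks.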
struct Mlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let up_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let down_proj = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }

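    // If the quant method requires a specific activation dtype, cast the input up
    // front and cast the result back to the original dtype at the end.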
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate_proj.forward(&xs)?.apply(&self.act_fn)?;
        let rhs = self.up_proj.forward(&xs)?;
        self.down_proj
            .forward(&(lhs * rhs)?)?
            .to_dtype(original_dtype)
    }
}

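/// Self-attention with grouped KV heads and tensor-parallel projections: q/k/v
/// are column-parallel, the output projection is row-parallel, and the stored
/// head counts are already divided by the communicator's world size.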
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

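    /// Projects `xs` to q/k/v, applies the rotary embedding, appends to the KV
    /// cache, runs SDPA, and maps the result back through the output projection.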
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        cos_sin: &(Tensor, Tensor),
        kv_cache: &mut KvCache,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&xs)?;
        let mut k = self.k_proj.forward(&xs)?;
        let mut v = self.v_proj.forward(&xs)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

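        // Prefill (q_len > 1): reshape to (b, q_len, heads, head_dim) and transpose
        // to (b, heads, q_len, head_dim). For single-token decode the transpose is
        // a no-op, so reshape directly into the target layout.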
        let (mut q, mut k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

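        // Move the precomputed cos/sin tables onto q's device before applying the
        // rotary embedding to q and k in place.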
        let cos_sin_relocated = &(
            cos_sin.0.to_device(q.device())?,
            cos_sin.1.to_device(q.device())?,
        );
        self.rotary_emb.forward(cos_sin_relocated, &mut q, &mut k)?;

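        // Append the new k/v to the cache and run attention in f32 for numerical
        // stability, casting back to q's dtype afterwards.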
        let mut attn_output = {
            let (k, v) = kv_cache.append(&k, &v)?;

            Sdpa.run_attention(
                &q.contiguous()?.to_dtype(DType::F32)?,
                &k.contiguous()?.to_dtype(DType::F32)?,
                &v.contiguous()?.to_dtype(DType::F32)?,
                attention_mask
                    .map(|mask| mask.to_dtype(DType::F32).unwrap())
                    .as_ref(),
                Some(flash_params),
                &self.sdpa_params,
            )?
            .to_dtype(q.dtype())?
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

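/// A single pre-norm decoder layer: RMSNorm -> self-attention -> residual add,
/// then RMSNorm -> MLP -> residual add.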
pub struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: F32RmsNorm,
    post_attention_layernorm: F32RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<Qwen2_5VLRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            comm,
        )?;
        let mlp = Mlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        cos_sin: &(Tensor, Tensor),
        kv_cache: &mut KvCache,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(&xs, attention_mask, cos_sin, kv_cache, flash_params)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

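/// Text (language) backbone of Qwen2.5-VL: token embeddings, a stack of decoder
/// layers with a per-layer KV cache, a final RMSNorm, and the LM head.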
pub struct Qwen2_5VLTextModel {
    embed_tokens: Embedding,
    pub(super) norm: F32RmsNorm,
    layers: Vec<DecoderLayer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    lm_head: Arc<dyn QuantMethod>,
    pub(super) cache: EitherCache,
    pub(super) cfg: ModelConfigMetadata,
    pub(super) device: Device,
    pub(super) dtype: DType,
    pub(super) max_seq_len: usize,
}

impl Qwen2_5VLTextModel {
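    /// Loads the text model from `vb`, bailing if `attention_mechanism` is not
    /// `AttentionImplementation::Eager`.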
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected eager attention implementation");
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

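        // Create one rotary embedding per unique device so that all layers mapped
        // to the same device share the same cos/sin tables.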
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Qwen2_5VLRotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    device,
                    cfg.rope_scaling.mrope_section.clone(),
                )?),
            );
        }

        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
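        // With tied word embeddings, reuse the embedding matrix as the LM head;
        // otherwise load a dedicated `lm_head` weight (without a quantization config).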
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            norm,
            layers,
            lm_head,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            device: normal_loading_metadata.real_device.clone(),
            dtype: vb.dtype(),
            mapper,
        })
    }

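    /// Looks up the token embeddings for `input_ids`.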
    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

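    /// Runs the decoder stack over precomputed input embeddings and returns the
    /// logits selected by `context_lens`.
    ///
    /// Illustrative sketch only; the variable names below are assumptions, not
    /// part of this crate's API:
    ///
    /// ```ignore
    /// let embeds = model.embed_tokens(&input_ids)?;
    /// let logits = model.forward_embeds(
    ///     embeds,
    ///     attention_mask.as_ref(),
    ///     &position_ids,
    ///     context_lens,
    ///     &flash_params,
    /// )?;
    /// ```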
    pub fn forward_embeds(
        &self,
        mut xs: Tensor,
        attention_mask: Option<&Tensor>,
        position_ids: &Tensor,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let cache = &mut self.cache.normal().0;
        let cos_sin = self.layers[0]
            .self_attn
            .rotary_emb
            .compute_cos_sin(position_ids, xs.dtype())?;

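        // Walk the decoder stack, moving the activations (and the mask) onto each
        // layer's mapped device and updating that layer's KV cache.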
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                &cos_sin,
                &mut cache[i],
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&self.lm_head.forward(&xs)?, context_lens)
    }
}

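// In-situ quantization (ISQ) support: expose every quantizable projection
// (lm_head plus each layer's attention/MLP projections) with its layer index,
// and the residual tensors that stay unquantized.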
impl IsqModel for Qwen2_5VLTextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.mlp.gate_proj, Some(i)));
            tensors.push((&mut layer.mlp.up_proj, Some(i)));
            tensors.push((&mut layer.mlp.down_proj, Some(i)));
        }
        (tensors, &*self.mapper)
    }

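    // Non-quantized tensors (embeddings and layer norms) that must be carried
    // over alongside the quantized weights, keyed by their original names.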
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}