use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, F32RmsNorm, Qwen2VLRotaryEmbedding, Sdpa},
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        extract_logits, text_models_inputs_processor::FlashParams, EitherCache, IsqModel, KvCache,
        NormalCache, NormalLoadingMetadata,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

use super::config::Config;

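/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
/// The gate/up projections are column-parallel and the down projection is
/// row-parallel, so the MLP is sharded across the tensor-parallel group.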
struct Mlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let up_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let down_proj = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate_proj.forward(&xs)?.apply(&self.act_fn)?;
        let rhs = self.up_proj.forward(&xs)?;
        self.down_proj
            .forward(&(lhs * rhs)?)?
            .to_dtype(original_dtype)
    }
}

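/// Self-attention with grouped-query KV heads and Qwen2-VL's multimodal
/// rotary embedding (M-RoPE). Q/K/V are column-parallel and the output
/// projection is row-parallel; per-rank head counts are set in `new`.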
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            true,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

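    /// Projects the hidden states to Q/K/V, applies M-RoPE via the precomputed
    /// `cos_sin`, appends K/V to this layer's KV cache, and runs SDPA in f32
    /// before the output projection. Activations are temporarily cast whenever
    /// a projection requires a quantized activation dtype.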
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        cos_sin: &(Tensor, Tensor),
        kv_cache: &mut KvCache,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&xs)?;
        let mut k = self.k_proj.forward(&xs)?;
        let mut v = self.v_proj.forward(&xs)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (mut q, mut k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        self.rotary_emb.forward(cos_sin, &mut q, &mut k)?;

        let mut attn_output = {
            let (k, v) = kv_cache.append(&k, &v)?;

            Sdpa.run_attention(
                &q.contiguous()?.to_dtype(DType::F32)?,
                &k.contiguous()?.to_dtype(DType::F32)?,
                &v.contiguous()?.to_dtype(DType::F32)?,
                attention_mask
                    .map(|mask| mask.to_dtype(DType::F32).unwrap())
                    .as_ref(),
                Some(flash_params),
                &self.sdpa_params,
            )?
            .to_dtype(q.dtype())?
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

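/// One transformer decoder block: pre-norm attention and MLP, each wrapped in
/// a residual connection.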
pub struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: F32RmsNorm,
    post_attention_layernorm: F32RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<Qwen2VLRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            comm,
        )?;
        let mlp = Mlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

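    /// Pre-norm residual forward: `h = x + attn(input_layernorm(x))`, then
    /// `h + mlp(post_attention_layernorm(h))`.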
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        cos_sin: &(Tensor, Tensor),
        kv_cache: &mut KvCache,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(&xs, attention_mask, cos_sin, kv_cache, flash_params)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

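/// Text (language-model) backbone of Qwen2-VL: token embeddings, a stack of
/// decoder layers, a final RMS norm, and the LM head, together with the KV
/// cache, device mapper, and model metadata used by the pipeline.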
pub struct Qwen2VLTextModel {
    embed_tokens: Embedding,
    pub(super) norm: F32RmsNorm,
    layers: Vec<DecoderLayer>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    lm_head: Arc<dyn QuantMethod>,
    pub(super) cache: EitherCache,
    pub(super) cfg: ModelConfigMetadata,
    pub(super) device: Device,
    pub(super) dtype: DType,
    pub(super) max_seq_len: usize,
}

impl Qwen2VLTextModel {
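    /// Loads the text model. Weights may live under either the
    /// `language_model.model.*` or the `model.*` prefix; one rotary embedding
    /// is built per mapped device, decoder layers are loaded (in parallel when
    /// applying ISQ), and the LM head reuses the token embeddings when
    /// `tie_word_embeddings` is set.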
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected eager attention implementation");
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = if vb.contains_tensor("language_model.model.embed_tokens.weight") {
            vb.pp("language_model").pp("model")
        } else {
            vb.pp("model")
        };

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Qwen2VLRotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    device,
                    cfg.rope_scaling.mrope_section.clone(),
                )?),
            );
        }

        let vb_l = vb_m.pp("layers");
        let layers = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                &comm,
            )
        })?;
        let norm = F32RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            norm,
            layers,
            lm_head,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            device: normal_loading_metadata.real_device.clone(),
            dtype: vb.dtype(),
            mapper,
        })
    }

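    /// Looks up token embeddings for `input_ids`.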
    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

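    /// Runs the decoder stack over pre-computed input embeddings `xs`
    /// (typically the token embeddings with any multimodal features already
    /// merged in by the caller), using `position_ids` to drive M-RoPE, and
    /// returns logits for the positions selected by `context_lens`.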
    pub fn forward_embeds(
        &self,
        mut xs: Tensor,
        attention_mask: Option<&Tensor>,
        position_ids: &Tensor,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let cache = &mut self.cache.normal().0;
        let cos_sin = self.layers[0]
            .self_attn
            .rotary_emb
            .compute_cos_sin(position_ids, xs.dtype())?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                &cos_sin,
                &mut cache[i],
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&self.lm_head.forward(&xs)?, context_lens)
    }
}

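// ISQ support: `get_layers` exposes the quantizable linear layers (the
// per-layer attention and MLP projections plus the LM head), while
// `residual_tensors` collects the non-quantized tensors (embeddings and norms)
// for serialization.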
impl IsqModel for Qwen2VLTextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.mlp.gate_proj, Some(i)));
            tensors.push((&mut layer.mlp.up_proj, Some(i)));
            tensors.push((&mut layer.mlp.down_proj, Some(i)));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}