#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::AdapterSwapper;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

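/// Gated feed-forward block for Phi-3. `ffn_up` holds the fused `gate_up_proj` (gate and up
/// projections in a single matmul); its output is split, SiLU-gated, and sent through
/// `ffn_down` (`down_proj`).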
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
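    /// The fused up projection yields `[gate | up]` along the last dimension, each half of
    /// width `i_size`; the up half is multiplied by `silu(gate)` before the down projection.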
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

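/// Dequantizes a GGUF norm weight on its own device and wraps it in an `RmsNorm`.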
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

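/// Weights and derived state for one transformer block: QKV/output projections with optional
/// LoRA adapters, pre-attention and pre-FFN RMS norms, the gated MLP, RoPE cos/sin tables, and
/// the sliding-window attention parameters.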
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
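    /// Applies rotary position embeddings (RoPE). Each sequence in the batch may start at a
    /// different position, so the cos/sin tables are sliced at that sequence's offset before
    /// `rope` is applied.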
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

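    /// Attention for one block: the fused QKV projection (with LoRA) is split into Q, K, and V,
    /// RoPE is applied, the sliding-window KV cache is updated, SDPA runs with the per-layer
    /// parameters, and the result goes back through the `o_proj` adapter layer.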
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            mask,
            Some(self.sliding_window),
            true,
        )?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

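/// The complete quantized Phi-3 X-LoRA model: embeddings, transformer blocks, final norm,
/// the `lm_head` projection (as a `QLoraLinear`), an optional device mapper, the KV cache,
/// and an optional X-LoRA classifier.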
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

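/// Precomputes the RoPE cos/sin tables. Dimension pair `i` rotates at frequency
/// `1 / freq_base^(2i / head_dim)`, and the returned tensors have shape
/// `(context_window, head_dim / 2)` in the requested dtype.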
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

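// Loads the quantized Phi-3 weights from a GGUF file and attaches (Q)LoRA adapters to each
// supported projection, optionally spreading layers across devices via the `DeviceMapper`.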
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

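        // Parameter extraction from metadata.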
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
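        // Load each transformer block, placing it on its mapped device and wiring LoRA adapters
        // into the QKV/output and gate_up/down projections. `count` tracks the total number of
        // adapters created, which the X-LoRA classifier needs below.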
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
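        // Without X-LoRA the adapter set is static, so the LoRA weights are merged into the
        // base layers up front.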
        if xlora_config.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
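        // The `lm_head` output projection may itself carry a LoRA adapter; this is rejected
        // below when X-LoRA is in use.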
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
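    /// Activates the named LoRA adapters in every layer, returning the total count reported by
    /// the individual layers. Unsupported for X-LoRA, whose adapter set must stay fixed.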
    pub fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += layer.attn_qkv.activate(&adapter_names)?;
            sum += layer.attn_output.activate(&adapter_names)?;
            sum += layer.mlp.ffn_down.activate(&adapter_names)?;
            sum += layer.mlp.ffn_up.activate(&adapter_names)?;
        }
        Ok(sum)
    }

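    /// Core transformer pass: embeds the input ids, maps activations across devices per layer,
    /// runs attention and the gated MLP with residual connections under a sliding-window causal
    /// mask, and applies the final RMS norm. A full pass uses the dedicated X-LoRA cache
    /// (`xlora_lock`), resetting it when the KV cache is disabled; other passes use the regular
    /// cache.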
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

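    /// Top-level forward. With an X-LoRA classifier, `get_scalings` first runs a scaling pass to
    /// obtain the adapter scalings, then a full pass applies them before the `lm_head`
    /// projection; without X-LoRA a single standard pass produces the logits.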
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

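// `ScalingsMaker` supplies what `get_scalings` needs to run the X-LoRA scaling pass: the dtype
// used for the scalings tensor, access to the cache and classifier, and a forward that threads
// the scalings through `inner_forward`.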
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}