#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

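/// Gated MLP over the fused `ffn_up` projection: its output is split into gate and up
/// halves, combined as `up * silu(gate)`, and projected back down by `ffn_down`.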
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
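        // The fused up projection packs [gate, up]; split it in half along the last dim.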
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

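/// Dequantizes a quantized norm weight and wraps it in an `RmsNorm`.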
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

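/// Weights for a single transformer block: fused QKV and output projections (with
/// optional LoRA), pre-attention and pre-FFN RMS norms, the gated MLP, and RoPE tables.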
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
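    /// Applies rotary position embeddings, indexing into the precomputed cos/sin tables
    /// at each sequence's own offset.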
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

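    /// Self-attention with a sliding-window KV cache; LoRA scalings are threaded through
    /// both the fused QKV projection and the output projection.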
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

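        // Prompt processing (seq_len > 1) needs a transpose after the reshape; single-token
        // decode can reshape straight to (b, heads, 1, head_dim).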
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

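        // Append k/v to the sliding-window KV cache and take back the mask to use for
        // this attention call.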
        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            mask,
            Some(self.sliding_window),
            true,
        )?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

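/// The full quantized Phi-3 model: embeddings, transformer blocks, final norm, and the
/// `lm_head`, plus an optional X-LoRA classifier and per-layer device mapper.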
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

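/// Precomputes the RoPE cos/sin tables for the full context window.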
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

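// Loads the quantized Phi-3 weights from a GGUF file and attaches LoRA/X-LoRA adapters
// to the supported projection layers.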
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
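        // Load each transformer block, wrapping every quantized projection in a
        // `QLoraLinear` so the configured adapters can be applied to it.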
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
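        // With plain LoRA (no X-LoRA classifier), the adapter deltas can be merged into
        // the quantized base weights up front.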
        if xlora_config.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
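    /// Runs the embedding, transformer blocks, and final norm; the caller applies the
    /// `output` head (optionally through LoRA) to turn the hidden states into logits.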
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
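        // Full passes run against the X-LoRA cache (cleared first when the KV cache is
        // disabled); otherwise the regular cache is used.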
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

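    /// X-LoRA entry point: when a classifier is present, a scaling pass first produces
    /// per-token adapter scalings and the scaled forward pass is then run; otherwise a
    /// single plain forward pass is run (the LoRA weights were merged at load time).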
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

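// Hooks this model into the shared X-LoRA scaling-pass machinery.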
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}