#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

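// Layer names that may carry (X-)LoRA adapters; `verify_sanity_adapters` checks the
// adapter ordering against this set in `from_gguf`.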
const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

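/// Phi-3 feed-forward block. `ffn_up` holds the fused `gate_up_proj`: it produces
/// `2 * i_size` features that are split into gate and up halves and combined with
/// SiLU (SwiGLU) before `ffn_down` projects back to the hidden size.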
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

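/// Dequantizes a GGUF norm weight and wraps it in a full-precision `RmsNorm`.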
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

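/// One transformer block: fused-QKV attention with RoPE and a sliding-window KV
/// cache, plus the SwiGLU MLP. Every projection is a `QLoraLinear` so (X-)LoRA
/// scalings can be threaded through it.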
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
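    /// Applies RoPE per sequence: each batch element takes the slice of the
    /// precomputed cos/sin tables starting at its own position offset.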
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

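    /// Fused-QKV attention: split the projection into Q/K/V heads, apply RoPE,
    /// update the sliding-window KV cache, then run SDPA and project back through
    /// `o_proj`, passing the (X-)LoRA scalings to both adapter-capable layers.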
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            mask,
            Some(self.sliding_window),
            true,
        )?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

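/// GGUF-quantized Phi-3 model with (X-)LoRA support: token embeddings, the stack of
/// [`LayerWeights`], the output norm and `lm_head`, and an optional X-LoRA
/// classifier that produces per-layer scalings.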
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

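// Precomputes RoPE tables of shape (context_window, head_dim / 2):
//   theta_i = 1 / freq_base^(2i / head_dim), table[pos, i] = pos * theta_i,
// returned as (cos(table), sin(table)) in the requested dtype.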
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

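// Loads the quantized Phi-3 weights from GGUF and attaches the LoRA adapters named by
// `lora_config`/`ordering`. With an `XLoraConfig` the classifier is built as well;
// without one the adapters are merged straight into the base weights.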
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
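        // No X-LoRA classifier: fold the plain LoRA adapters into the quantized base
        // weights once at load time.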
        if xlora_config.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
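    /// Runs the transformer trunk (embeddings, all blocks, final norm) without
    /// `lm_head`. `scalings` and `is_scaling_pass` thread the X-LoRA state through
    /// every layer; `is_full_pass` selects the X-LoRA cache instead of the regular one.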
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

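    /// Full X-LoRA forward pass: first obtain per-layer scalings from the classifier
    /// (the scaling pass), then rerun the model with those scalings and apply
    /// `lm_head`. Without a classifier this reduces to a single plain-LoRA pass.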
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

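// Lets the shared X-LoRA machinery drive the scaling pass through `inner_forward`;
// the scalings dtype is fixed to `F32`.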
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}