#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

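//! X-LoRA / LoRA support for the GGUF-quantized Phi 3 model: each supported
//! projection (see `SUPPORTED_LAYERS`) is wrapped in a `QLoraLinear` so that
//! adapters loaded from a `ShardedVarBuilder` can be applied on top of the
//! quantized base weights.
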
use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::NiceProgressBar;
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

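/// Feed-forward block. `ffn_up` is the fused gate/up projection (exposed to
/// adapters as `mlp.gate_up_proj`): its output is split into a gate half and an
/// up half of `i_size` columns each, combined as `up * silu(gate)` before
/// `ffn_down` projects back down.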
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
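    /// Apply rotary position embeddings, honoring a per-sequence position
    /// offset: sequence `i` of the batch reads `seq_len` rows of the
    /// precomputed cos/sin tables starting at `seqlen_offsets[i]`.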
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

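    /// Attention for one layer: split the fused QKV projection into heads,
    /// apply RoPE, update the sliding-window KV cache, run SDPA, and project
    /// the result back through `attn_output`.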
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) =
            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, Some(self.sliding_window))?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

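/// GGUF-quantized Phi 3 weights with LoRA/X-LoRA adapters attached to the
/// projections listed in `SUPPORTED_LAYERS`.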
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

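/// Precompute the rotary-embedding tables used by `apply_rotary_emb`. For each
/// frequency index `i = 0..head_dim/2` and position `p`:
///
/// ```text
/// theta_i   = 1 / freq_base^(2i / head_dim)
/// cos[p, i] = cos(p * theta_i)
/// sin[p, i] = sin(p * theta_i)
/// ```
///
/// Both returned tensors have shape `(context_window, head_dim / 2)`.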
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

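// GGUF loading: metadata is read under the `phi3` prefix, each per-layer tensor
// is wrapped in a `QLoraLinear` under its adapter name (e.g.
// `blk.0.self_attn.qkv_proj`), and, when X-LoRA is not in use, the adapters are
// merged into the base weights immediately after loading.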
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
        if xlora_config.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
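    /// Forward pass through the embedding, all transformer layers, and the
    /// final norm, without applying the `lm_head` projection. `scalings` and
    /// `is_scaling_pass` are threaded into every `QLoraLinear`; `is_full_pass`
    /// selects the dedicated X-LoRA cache (optionally cleared when
    /// `no_kv_cache` is set) instead of the regular KV cache.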
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

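    /// Top-level forward pass. With an X-LoRA classifier configured, a scaling
    /// pass (`get_scalings`) first predicts per-adapter scalings and the main
    /// pass then applies them; without one, a single plain pass is run. In both
    /// cases the `lm_head` projection is applied and logits are extracted for
    /// `context_lens`.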
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

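// Hooks used by the shared X-LoRA scaling machinery; the scaling pass is routed
// through `inner_forward`. Note that `dtype` here is `DType::F32`, independent
// of the model's own `dtype` field.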
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}