#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::CausalMasker;
use crate::layers::RmsNorm;
use crate::layers::Sdpa;
use crate::lora::get_lora_cfg;
use crate::lora::LinearLayerLike;
use crate::lora::LoraConfig;
use crate::lora::Merge;
use crate::lora::Ordering;
use crate::lora::QLoraLinear;
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::FlashParams;
use crate::pipeline::EitherCache;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};
use candle_core::quantized::QMatMul;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::Embedding;
use mistralrs_quant::ShardedVarBuilder;
use tqdm::Iter;
use tracing::info;

use super::classifier::XLoraClassifier;
use super::verify_sanity_adapters;
use super::Cache;
use super::NonGranularState;
use super::ScalingsMaker;
use super::XLoraConfig;
use crate::models::quantized_phi3::PropsGGUF;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;

const SUPPORTED_LAYERS: [&str; 5] = [
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
    "lm_head",
];

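/// Phi-3 feed-forward block: the fused `gate_up` projection output is split into a gate half and
/// an up half, combined as `up * silu(gate)`, then passed through the down projection.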
#[derive(Debug)]
struct Mlp {
    ffn_up: QLoraLinear,
    ffn_down: QLoraLinear,
    i_size: usize,
}

impl Mlp {
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let up_states = self.ffn_up.lora_forward(
            xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.silu()?)?;
        self.ffn_down
            .lora_forward(&up_states, scalings, global_scaling_weight, is_scaling_pass)
    }
}

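/// Dequantizes a GGUF norm weight and wraps it in an `RmsNorm` layer.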
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
    let w = w.dequantize(&w.device())?;
    let rms = RmsNorm::from_w(w, eps)?;
    Ok(rms)
}

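/// Per-block weights: LoRA-wrapped QKV/output projections, the two RMS norms, the MLP, and the
/// RoPE tables plus SDPA settings used by attention.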
struct LayerWeights {
    attn_qkv: QLoraLinear,
    attn_output: QLoraLinear,
    attn_norm: RmsNorm,
    ffn_norm: RmsNorm,
    mlp: Mlp,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    sliding_window: usize,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
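    /// Applies rotary position embeddings, indexing the cos/sin tables with each sequence's own
    /// offset so batched sequences at different positions are rotated correctly.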
    fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
        let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
        let mut outputs = Vec::new();
        for (i, offset) in seqlen_offsets.iter().enumerate() {
            let cos = self.cos.narrow(0, *offset, seq_len)?;
            let sin = self.sin.narrow(0, *offset, seq_len)?;
            outputs.push(candle_nn::rotary_emb::rope(
                &xs.i(i)?.unsqueeze(0)?.contiguous()?,
                &cos,
                &sin,
            )?);
        }
        Tensor::cat(&outputs, 0)
    }

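    /// Self-attention for one block: LoRA-aware QKV projection, RoPE, sliding-window KV cache
    /// update, SDPA, and the LoRA-aware output projection.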
    #[allow(clippy::too_many_arguments)]
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let qkv = self
            .attn_qkv
            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
            .to_dtype(self.dtype)?;

        let query_pos = self.n_head * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.n_kv_head * self.head_dim,
            self.n_kv_head * self.head_dim,
        )?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let q = self.apply_rotary_emb(&q, seqlen_offsets)?.contiguous()?;
        let k = self.apply_rotary_emb(&k, seqlen_offsets)?;

        let (k, v, attn_mask) =
            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, Some(self.sliding_window))?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.attn_output.lora_forward(
            &y.to_dtype(x.dtype())?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        Ok(y)
    }
}

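/// Quantized Phi-3 model with LoRA/X-LoRA adapters loaded from GGUF, plus the KV cache, an
/// optional device mapper, and an optional X-LoRA classifier.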
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    output_norm: RmsNorm,
    output: QLoraLinear,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
}

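/// Precomputes the RoPE cos/sin tables for every position up to `context_window` at the given
/// base frequency.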
fn precomput_freqs_cis(
    head_dim: usize,
    freq_base: f32,
    device: &Device,
    context_window: usize,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let theta: Vec<_> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    let idx_theta = Tensor::arange(0, context_window as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((context_window, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let cos = idx_theta.cos()?.to_dtype(dtype)?;
    let sin = idx_theta.sin()?.to_dtype(dtype)?;
    Ok((cos, sin))
}

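// Construction of the model from a GGUF file plus LoRA adapter weights (and, optionally, an
// X-LoRA classifier).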
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

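        // Read the Phi-3 hyperparameters from the GGUF metadata.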
        let metadata = ContentMetadata {
            path_prefix: "phi3",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            i_size,
            rope_dim,
            rms_eps,
            context_window,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window, dtype)?;

        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
        let output = ct.tensor("output.weight", device)?;
        let mut layers = Vec::with_capacity(block_count);

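        // Load each transformer block, wrapping every supported projection in a `QLoraLinear`;
        // `count` is threaded through each constructor and later handed to the X-LoRA classifier.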
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
            let cfg_up = get_lora_cfg(&ffn_up);
            let cfg_down = get_lora_cfg(&ffn_down);
            let mlp = Mlp {
                ffn_up: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_up)?,
                    &cfg_up,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.gate_up_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                ffn_down: QLoraLinear::new(
                    QMatMul::from_qtensor(ffn_down)?,
                    &cfg_down,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.mlp.down_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                i_size,
            };
            let attn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                rms_eps,
            )?;
            let ffn_norm = rms_norm(
                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                rms_eps,
            )?;
            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let cfg_qkv = get_lora_cfg(&qkv);
            let cfg_out = get_lora_cfg(&output);
            let head_dim = embedding_length / head_count;
            layers.push(LayerWeights {
                attn_qkv: QLoraLinear::new(
                    QMatMul::from_qtensor(qkv)?,
                    &cfg_qkv,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.qkv_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_output: QLoraLinear::new(
                    QMatMul::from_qtensor(output)?,
                    &cfg_out,
                    lora_config,
                    vb,
                    ordering,
                    format!("{prefix}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attn_norm,
                ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                cos: cos.to_device(device)?,
                sin: sin.to_device(device)?,
                sliding_window: context_window,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: Some(context_window),
                },
                dtype,
            })
        }
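        // Without X-LoRA, the adapters can be merged directly into the quantized base weights.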
        if xlora_config.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attn_qkv.merge_weights()?;
                layer.attn_output.merge_weights()?;
                layer.mlp.ffn_down.merge_weights()?;
                layer.mlp.ffn_up.merge_weights()?;
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            output_norm,
            output,
            mapper: Some(mapper),
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            max_seq_len: context_window,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            dtype,
        })
    }
}

impl ModelWeights {
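    /// Runs the transformer stack (embeddings, attention/MLP blocks, final norm) and returns the
    /// hidden states; the LM head is applied by the caller.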
    #[allow(clippy::too_many_arguments)]
    pub fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.tok_embeddings.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.max_seq_len),
            self.dtype,
            self.layers[0].n_head,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                xs = mapper.map(xs, i)?;
            }
            let residual = &xs;
            let ys = xs.apply(&layer.attn_norm)?;
            let ys = layer.forward_attn(
                &ys,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
            let ys = (ys + residual)?;
            let residual = &ys;
            let ys = ys.apply(&layer.ffn_norm)?;
            let ys = layer.mlp.forward(
                &ys,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
            )?;
            xs = (ys + residual)?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.output_norm)
    }

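    /// Pipeline entry point. With an X-LoRA classifier present, first computes adapter scalings
    /// and then runs a full pass using them; otherwise this is a plain LoRA forward pass. The LM
    /// head is applied and logits are extracted for the requested context lengths.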
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids_full,
                                seqlen_offsets_full,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params_full,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            } else {
                extract_logits(
                    &self.output.lora_forward(
                        &self
                            .inner_forward(
                                input_ids,
                                seqlen_offsets,
                                Some(scalings),
                                true,
                                no_kv_cache,
                                None,
                                flash_params,
                            )?
                            .contiguous()?,
                        None,
                        1.0,
                        None,
                    )?,
                    context_lens,
                )
            }
        } else {
            extract_logits(
                &self.output.lora_forward(
                    &self
                        .inner_forward(
                            input_ids,
                            seqlen_offsets,
                            None,
                            false,
                            no_kv_cache,
                            None,
                            flash_params,
                        )?
                        .contiguous()?,
                    None,
                    1.0,
                    None,
                )?,
                context_lens,
            )
        }
    }
}

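// Hooks used by the X-LoRA scaling machinery: expose the cache and classifier, and run a
// scaling-pass forward through the model.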
impl ScalingsMaker for ModelWeights {
    fn dtype(&self) -> DType {
        DType::F32
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}