1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7 TokenSource,
8};
9use super::{
10 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
15 MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
16 Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::get_chat_template;
26use crate::pipeline::isq::UqffFullSer;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{ChatTemplate, LocalModelPaths};
31use crate::prefix_cacher::PrefixCacheManagerV2;
32use crate::sequence::Sequence;
33use crate::utils::tokenizer::get_tokenizer;
34use crate::utils::varbuilder_utils::DeviceForLoadTensor;
35use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
36use crate::xlora_models::NonGranularState;
37use crate::{
38 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
39 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
40 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
41};
42use anyhow::Result;
43use candle_core::{Device, Tensor, Var};
44use hf_hub::Cache;
45use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
46use indicatif::MultiProgress;
47use mistralrs_quant::log::once_log_info;
48use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
49use rand_isaac::Isaac64Rng;
50use regex_automata::meta::Regex;
51use std::any::Any;
52use std::borrow::Cow;
53use std::num::{NonZero, NonZeroUsize};
54use std::path::{Path, PathBuf};
55use std::str::FromStr;
56use std::sync::{Arc, RwLock};
57use std::time::Instant;
58use std::{env, fs};
59use tokenizers::Tokenizer;
60use tokio::sync::Mutex;
61use tracing::{info, warn};
62
/// A fully loaded text-generation pipeline for "normal" (non-vision, non-GGUF)
/// models, produced by [`NormalLoader`].
pub struct NormalPipeline {
    /// The loaded model implementation (possibly quantized / device-mapped).
    model: Box<dyn NormalModel + Send + Sync>,
    /// Shared tokenizer for encoding prompts and decoding generations.
    tokenizer: Arc<Tokenizer>,
    /// When true, the KV cache is disabled (used by X-LoRA scan passes).
    no_kv_cache: bool,
    /// Chat template used to render conversations into prompts.
    chat_template: Arc<ChatTemplate>,
    /// Present only for X-LoRA models with a target non-granular index.
    non_granular_state: Option<NonGranularState>,
    /// Model ID this pipeline was loaded from (HF repo ID or local path).
    model_id: String,
    /// Shared generation metadata (EOS tokens, cache config, etc.).
    metadata: Arc<GeneralMetadata>,
    /// Optional per-layer device/ISQ topology, kept for re-quantization.
    topology: Option<Topology>,
    /// Suppress progress output when true.
    silent: bool,
    /// Whether ISQ applies to all weights or MoE experts only.
    organization: IsqOrganization,
    /// Path to the chat template file, kept for UQFF serialization.
    template_filename: Option<PathBuf>,
    /// Path to generation_config.json, kept for UQFF serialization.
    generation_config: Option<PathBuf>,
    /// Raw model config JSON contents.
    config: String,
    /// Optional imatrix file used when re-applying ISQ.
    imatrix: Option<PathBuf>,
    /// Device mapper describing which layers live on which device.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
81
/// A [`Loader`] for normal (text) models; construct via [`NormalLoaderBuilder`].
pub struct NormalLoader {
    /// Architecture-specific loader (selected from `NormalLoaderType` or auto-detected).
    inner: Box<dyn NormalModelLoader>,
    /// HF repo ID or local path of the base model.
    model_id: String,
    /// Loader-specific configuration (ISQ, UQFF, topology, ...).
    config: NormalSpecificConfig,
    /// Optional X-LoRA adapter model ID.
    xlora_model_id: Option<String>,
    /// Optional LoRA adapter model IDs.
    lora_adapter_ids: Option<Vec<String>>,
    /// Kind of model being loaded (normal or adapter).
    kind: ModelKind,
    /// Adapter ordering, required for X-LoRA.
    xlora_order: Option<Ordering>,
    /// Disable the KV cache (X-LoRA non-granular scans).
    no_kv_cache: bool,
    /// Literal chat template or path override.
    chat_template: Option<String>,
    /// Optional explicit tokenizer.json path.
    tokenizer_json: Option<String>,
    /// X-LoRA target non-granular index, if any.
    tgt_non_granular_index: Option<usize>,
    // The RwLock fields below are written during `load_model_from_hf` so the
    // later `load_model_from_path` call (which only takes `&self`) can read them.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Optional explicit JINJA chat template file.
    jinja_explicit: Option<String>,
    /// Optional override of the Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
}
101
/// Builder for [`NormalLoader`]; start with [`NormalLoaderBuilder::new`] and
/// finish with [`NormalLoaderBuilder::build`].
#[derive(Default)]
pub struct NormalLoaderBuilder {
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
118
/// Configuration specific to loading normal (text) models.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Prompt chunk size; falls back to `DEFAULT_PROMPT_CHUNK_SIZE` when `None`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Optional per-layer device/ISQ topology.
    pub topology: Option<Topology>,
    /// Apply ISQ to all weights or to MoE experts only.
    pub organization: IsqOrganization,
    /// When set, serialize the quantized model to this UQFF path.
    pub write_uqff: Option<PathBuf>,
    /// When set, load pre-quantized weights from these UQFF files.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Optional llama.cpp-style imatrix file for ISQ.
    pub imatrix: Option<PathBuf>,
    /// Optional plain-text calibration file for collecting an imatrix.
    pub calibration_file: Option<PathBuf>,
    /// Optional override of the Hugging Face cache directory.
    pub hf_cache_path: Option<PathBuf>,
}
131
132impl NormalLoaderBuilder {
133 pub fn new(
134 config: NormalSpecificConfig,
135 chat_template: Option<String>,
136 tokenizer_json: Option<String>,
137 model_id: Option<String>,
138 no_kv_cache: bool,
139 jinja_explicit: Option<String>,
140 ) -> Self {
141 Self {
142 config,
143 chat_template,
144 tokenizer_json,
145 model_id,
146 kind: ModelKind::Normal,
147 jinja_explicit,
148 no_kv_cache,
149 ..Default::default()
150 }
151 }
152
153 fn with_adapter(
154 mut self,
155 xlora_model_id: String,
156 xlora_order: Ordering,
157 no_kv_cache: bool,
158 tgt_non_granular_index: Option<usize>,
159 ) -> Self {
160 self.xlora_model_id = Some(xlora_model_id);
161 self.xlora_order = Some(xlora_order);
162 self.no_kv_cache = no_kv_cache;
163 self.tgt_non_granular_index = tgt_non_granular_index;
164 self.model_id = if let Some(id) = self.model_id {
165 Some(id)
166 } else {
167 info!(
168 "Using adapter base model ID: `{}`",
169 self.xlora_order.as_ref().unwrap().base_model_id
170 );
171 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
172 };
173 self
174 }
175
176 pub fn with_xlora(
177 mut self,
178 xlora_model_id: String,
179 xlora_order: Ordering,
180 no_kv_cache: bool,
181 tgt_non_granular_index: Option<usize>,
182 ) -> Self {
183 self.kind = ModelKind::Adapter {
184 adapter: AdapterKind::XLora,
185 };
186 self.with_adapter(
187 xlora_model_id,
188 xlora_order,
189 no_kv_cache,
190 tgt_non_granular_index,
191 )
192 }
193
194 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
195 self.kind = ModelKind::Adapter {
196 adapter: AdapterKind::Lora,
197 };
198 self.lora_adapter_ids = Some(lora_adapter_ids);
199 self
200 }
201
202 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
203 self.hf_cache_path = Some(hf_cache_path);
204 self
205 }
206
207 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
210 let loader: Box<dyn NormalModelLoader> = match loader_tp {
211 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
212 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
213 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
214 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
215 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
216 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
217 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
218 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
219 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
220 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
221 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
222 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
223 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
224 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
225 None => Box::new(AutoNormalLoader),
226 };
227 Ok(Box::new(NormalLoader {
228 inner: loader,
229 model_id: self.model_id.unwrap(),
230 config: self.config,
231 xlora_model_id: self.xlora_model_id,
232 lora_adapter_ids: self.lora_adapter_ids,
233 kind: self.kind,
234 xlora_order: self.xlora_order,
235 no_kv_cache: self.no_kv_cache,
236 chat_template: self.chat_template,
237 tokenizer_json: self.tokenizer_json,
238 tgt_non_granular_index: self.tgt_non_granular_index,
239 jinja_explicit: self.jinja_explicit,
240 token_source: RwLock::new(None),
241 revision: RwLock::new(None),
242 from_uqff: RwLock::new(None),
243 hf_cache_path: self.hf_cache_path,
244 }))
245 }
246}
247
impl Loader for NormalLoader {
    /// Resolve model files from the Hugging Face Hub (or local cache), stash
    /// the token source / revision / UQFF paths for later use, then delegate
    /// to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Install the (possibly overridden) HF cache globally; `get_or_init`
        // means a previously installed cache wins.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // Resolve UQFF artifact paths (may download) before loading.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        // Record token source and revision so `load_model_from_path` (which
        // only takes `&self`) can access them via the RwLocks.
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load the model from already-resolved paths: pick devices, compute a
    /// device map, load weights (optionally sharded over NCCL), apply ISQ /
    /// UQFF, optionally collect an imatrix from a calibration file, set up
    /// PagedAttention caches, and assemble the final [`NormalPipeline`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        // Architectures that don't support PagedAttention fall back to eager.
        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.",);

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        // Device selection: a distributed daemon worker gets CUDA ordinal
        // `worker_rank + 1`; NCCL leader uses CUDA 0; otherwise enumerate all
        // devices similar to the requested one for device mapping.
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        if use_nccl {
            // NCCL handles sharding itself; use a dummy single-device map.
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Automatic device mapping: estimate per-layer sizes so layers can
            // be distributed across the available devices.
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // If the model is already quantized (pack factor != 1), ISQ does
            // not apply on top of it.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // Size estimation uses the effective weight pack factor from, in
            // priority order: UQFF artifacts, the requested ISQ type, or the
            // model's own quantization config.
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // Each UQFF tensor records its quantization type
                            // at a fixed header offset.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        // Average pack factor across all serialized tensors.
                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = self.inner.get_device_layers(
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        // Two mappers are built from the same setting: one kept for the
        // pipeline, one consumed while loading.
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // PagedAttention has no CPU support, so any CPU participation in the
        // device map disables it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // "Immediate ISQ" quantizes layers as they are loaded instead of
        // afterwards; it is used only when no imatrix/calibration/UQFF-write
        // is involved and the target device is not CUDA.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
            {
                self.inner.immediate_isq_predicates_moqe(&config)?
            } else {
                self.inner.immediate_isq_predicates(&config)?
            };
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            // Immediate ISQ replaces the post-load quantization pass.
            false
        } else {
            in_situ_quant.is_some()
        };

        // A topology can request ISQ for specific layers even without a
        // global ISQ setting.
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // With post-load ISQ, weights are first loaded to CPU and quantized
        // to the target device; calibration requires loading to the device.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        // Load the model itself: NCCL path uses a sharded varbuilder; the
        // regular path dispatches on the model kind (normal / X-LoRA / LoRA).
        let mut model = if use_nccl {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        // Calibration pass: run the tokenized calibration text through the
        // model in chunks while it tracks activation statistics (imatrix).
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                // Each chunk is prefixed with the BOS token.
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                // Clear the KV cache between chunks so calibration chunks do
                // not attend to each other.
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Post-load quantization (ISQ), unless weights came pre-quantized
        // from UQFF, in which case the artifacts are applied instead.
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                // Both set is rejected earlier with an error.
                (Some(_), true) => unreachable!(),
            };

            info!("Applying ISQ to all ranks.");

            let multi_progress = Arc::new(MultiProgress::new());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        // X-LoRA does not support PagedAttention.
        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!(
                "Adapter parallel_models do not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        // Set up the PagedAttention KV-cache engine if enabled.
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(model_metadata),
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
895
896impl PreProcessingMixin for NormalPipeline {
897 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
898 Some(self.chat_template.clone())
899 }
900 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
901 None
902 }
903}
904
905impl IsqPipelineMixin for NormalPipeline {
906 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
907 let device = self.device().clone();
908 let multi_progress = Arc::new(MultiProgress::new());
909 self.model.quantize(
910 Some(dtype),
911 device.clone(),
912 self.topology.as_ref(),
913 self.silent,
914 self.imatrix.as_ref().map(ImatrixDataSource::File),
915 self.organization,
916 None,
917 UqffFullSer {
918 tokenizer: &self.tokenizer,
919 template_filename: &self.template_filename,
920 generation_config: self.generation_config.as_ref(),
921 config: self.config.clone(),
922 processor_filename: &None,
923 preprocessor_filename: &None,
924 },
925 multi_progress.clone(),
926 )?;
927 Ok(())
928 }
929}
930
931impl CacheManagerMixin for NormalPipeline {
932 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
933 if matches!(self.model.cache(), EitherCache::Full(_)) {
934 FullCacheManager.clone_in_cache(self, seqs, false)
935 } else {
936 NormalCacheManager.clone_in_cache(self, seqs, false)
937 }
938 }
939 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
940 if matches!(self.model.cache(), EitherCache::Full(_)) {
941 FullCacheManager.clone_out_cache(self, seqs, false)
942 } else {
943 NormalCacheManager.clone_out_cache(self, seqs, false)
944 }
945 }
946 fn set_none_cache(
947 &self,
948 seqs: &mut [&mut Sequence],
949 reset_non_granular: bool,
950 modify_draft_cache: bool,
951 load_preallocated_cache: bool,
952 ) {
953 if matches!(self.model.cache(), EitherCache::Full(_)) {
954 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
955 } else {
956 NormalCacheManager.set_none_cache(
957 self,
958 seqs,
959 modify_draft_cache,
960 load_preallocated_cache,
961 );
962 }
963 if reset_non_granular {
964 self.reset_non_granular_state()
965 }
966 }
967 fn cache(&self) -> &EitherCache {
968 self.model.cache()
969 }
970}
971
972impl MetadataMixin for NormalPipeline {
973 fn device(&self) -> Device {
974 self.model.device().clone()
975 }
976 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
977 Some(self.tokenizer.clone())
978 }
979 fn name(&self) -> String {
980 self.model_id.clone()
981 }
982 fn reset_non_granular_state(&self) {
983 if let Some(s) = self.non_granular_state.as_ref() {
984 *self.cache().full().get_scalings_cache() = None;
985 *get_mut_arcmutex!(s.non_granular_index) = 0;
986 }
987 }
988 fn get_metadata(&self) -> Arc<GeneralMetadata> {
989 self.metadata.clone()
990 }
991 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
992 Some(&*self.mapper)
993 }
994}
995
996#[async_trait::async_trait]
997impl Pipeline for NormalPipeline {
998 fn forward_inputs(
999 &mut self,
1000 inputs: Box<dyn Any>,
1001 return_raw_logits: bool,
1002 ) -> Result<ForwardInputsResult, candle_core::Error> {
1003 let ModelInputs {
1004 input_ids,
1005 input_ids_full,
1006 seqlen_offsets,
1007 seqlen_offsets_full,
1008 context_lens,
1009 position_ids,
1010 paged_attn_meta,
1011 flash_meta,
1012 flash_meta_full,
1013 } = *inputs.downcast().expect("Downcast failed.");
1014 let metadata = self.get_metadata();
1015 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1016 (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1017 (Some(_), None) => {
1018 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1020 }
1021 (None, Some(_)) => {
1022 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1024 }
1025 (None, None) => None,
1026 };
1027 let logits = match self.model.is_xlora() {
1028 false => {
1029 let paged_attn_meta = paged_attn_meta
1030 .as_ref()
1031 .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1032
1033 self.model.forward(
1034 &input_ids,
1035 &seqlen_offsets,
1036 context_lens,
1037 position_ids,
1038 paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1039 &flash_meta,
1040 )?
1041 }
1042 true => self.model.xlora_forward(
1043 &input_ids,
1044 input_ids_full.as_ref().unwrap_or(&input_ids),
1045 &seqlen_offsets,
1046 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1047 self.no_kv_cache,
1048 &self.non_granular_state,
1049 context_lens,
1050 position_ids,
1051 &flash_meta,
1052 flash_meta_full.as_ref().unwrap_or(&flash_meta),
1053 )?,
1054 };
1055 if return_raw_logits {
1056 Ok(ForwardInputsResult::RawLogits { logits })
1057 } else {
1058 Ok(ForwardInputsResult::CausalGeneration { logits })
1059 }
1060 }
1061 async fn sample_causal_gen(
1062 &self,
1063 seqs: &mut [&mut Sequence],
1064 logits: Vec<Tensor>,
1065 prefix_cacher: &mut PrefixCacheManagerV2,
1066 disable_eos_stop: bool,
1067 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1068 ) -> Result<(), candle_core::Error> {
1069 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1070 }
1071 fn category(&self) -> ModelCategory {
1072 ModelCategory::Text
1073 }
1074}
1075
impl AnyMoePipelineMixin for NormalPipeline {
    /// Finalize AnyMoE training, optionally saving the gating model.
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    /// Trainable variables, grouped per layer.
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    /// Number of trainable parameters in the base model.
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    /// Take (and clear) cached gating outputs from the model.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    /// Build AnyMoE layers from expert models.
    ///
    /// Downloads each expert model's safetensors from the HF Hub (or cache),
    /// loads only tensors whose names match `match_regex` and whose layer
    /// index is in `layers` (empty = all layers), optionally loads a
    /// single-file gating model, and hands everything to the model.
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile the regex; errors are surfaced as candle errors.
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            // Build an HF API client against the global cache, honoring an
            // HF_HUB_CACHE override.
            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            // Collect every .safetensors file in the expert repo.
            let mut filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                // Load only tensors matching the regex whose layer index is
                // selected (empty `layers` selects all).
                move |key| {
                    if regex.is_match(&key) {
                        // Extract the layer number from the segment just
                        // before the matched suffix, e.g. `...layers.<n>.<match>`.
                        // NOTE(review): assumes the match is preceded by
                        // `.<digits>.` in the key — confirm against expert
                        // checkpoints.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        // Optionally load the gating model; it must be exactly one
        // safetensors file.
        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model.create_anymoe_layers(
            vbs.clone(),
            config.clone(),
            (prefix.clone(), mlp.clone()),
            layers.clone(),
            expert_type.clone(),
            gate_vb.clone(),
        )?;

        Ok(())
    }
    /// Whether the underlying model supports AnyMoE.
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}