1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6 TokenSource,
7};
8use super::{
9 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
14 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
15 Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
16};
17use crate::amoe::AnyMoeExpertType;
18use crate::attention::ATTENTION_CHUNK_SIZE;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{
50 AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
51};
52use rand_isaac::Isaac64Rng;
53use regex_automata::meta::Regex;
54use std::any::Any;
55use std::borrow::Cow;
56use std::path::{Path, PathBuf};
57use std::str::FromStr;
58use std::sync::{Arc, RwLock};
59use std::time::Instant;
60use std::{env, fs};
61use tokenizers::Tokenizer;
62use tokio::sync::Mutex;
63use tracing::{info, warn};
64
/// Text-model pipeline produced by [`NormalLoader`].
///
/// Wraps the loaded [`NormalModel`] together with its tokenizer, chat template and metadata, and
/// keeps the load-time settings (topology, organization, imatrix, template/config files) that
/// `re_isq_model` needs to re-quantize the model later.
pub struct NormalPipeline {
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    /// Copied from the loader; forwarded to X-LoRA forward passes.
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    /// Only present when a target non-granular index was configured (X-LoRA).
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    // The following fields are retained so the model can be re-quantized after loading.
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    /// Raw contents of the model config file, as read from disk.
    config: String,
    /// Imatrix file reused as the imatrix source when re-quantizing.
    imatrix: Option<PathBuf>,
    /// Device mapper exposed through `device_mapper()`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
83
/// [`Loader`] for normal (text) models, optionally with X-LoRA or LoRA adapters.
///
/// Constructed by [`NormalLoaderBuilder::build`].
pub struct NormalLoader {
    /// Architecture-specific loader chosen from `NormalLoaderType` (`AutoNormalLoader` for `None`).
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    // These start as `None` in `build` and are written by `load_model_from_hf`.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    /// Resolved UQFF artifact paths; read during auto device mapping and UQFF loading.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
103
/// Builder for [`NormalLoader`].
///
/// Start with [`NormalLoaderBuilder::new`], optionally attach adapters with `with_xlora` or
/// `with_lora`, then call `build` with the architecture (`None` selects `AutoNormalLoader`).
#[derive(Default)]
pub struct NormalLoaderBuilder {
    /// May remain `None` until `with_xlora` fills it from the ordering's `base_model_id`.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
120
/// Loader options specific to normal (text) models: quantization, UQFF, imatrix and MatFormer.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Optional pattern overrides (ISQ type and/or device) applied to matching weights.
    pub topology: Option<Topology>,
    /// Which weights ISQ targets (e.g. `MoeExpertsOnly` restricts it to MoE experts).
    pub organization: IsqOrganization,
    /// When set, the model is serialized to this UQFF path after loading.
    pub write_uqff: Option<PathBuf>,
    /// UQFF artifacts to load quantized weights from; disables the post-load ISQ pass.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Precomputed imatrix file. Mutually exclusive with `calibration_file`.
    pub imatrix: Option<PathBuf>,
    /// Text file used to collect an imatrix at load time. Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    /// MatFormer config file; a slice is only applied if `matformer_slice_name` is also set.
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}
134
135impl NormalLoaderBuilder {
136 pub fn new(
137 config: NormalSpecificConfig,
138 chat_template: Option<String>,
139 tokenizer_json: Option<String>,
140 model_id: Option<String>,
141 no_kv_cache: bool,
142 jinja_explicit: Option<String>,
143 ) -> Self {
144 Self {
145 config,
146 chat_template,
147 tokenizer_json,
148 model_id,
149 kind: ModelKind::Normal,
150 jinja_explicit,
151 no_kv_cache,
152 ..Default::default()
153 }
154 }
155
156 fn with_adapter(
157 mut self,
158 xlora_model_id: String,
159 xlora_order: Ordering,
160 no_kv_cache: bool,
161 tgt_non_granular_index: Option<usize>,
162 ) -> Self {
163 self.xlora_model_id = Some(xlora_model_id);
164 self.xlora_order = Some(xlora_order);
165 self.no_kv_cache = no_kv_cache;
166 self.tgt_non_granular_index = tgt_non_granular_index;
167 self.model_id = if let Some(id) = self.model_id {
168 Some(id)
169 } else {
170 info!(
171 "Using adapter base model ID: `{}`",
172 self.xlora_order.as_ref().unwrap().base_model_id
173 );
174 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
175 };
176 self
177 }
178
179 pub fn with_xlora(
180 mut self,
181 xlora_model_id: String,
182 xlora_order: Ordering,
183 no_kv_cache: bool,
184 tgt_non_granular_index: Option<usize>,
185 ) -> Self {
186 self.kind = ModelKind::Adapter {
187 adapter: AdapterKind::XLora,
188 };
189 self.with_adapter(
190 xlora_model_id,
191 xlora_order,
192 no_kv_cache,
193 tgt_non_granular_index,
194 )
195 }
196
197 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
198 self.kind = ModelKind::Adapter {
199 adapter: AdapterKind::Lora,
200 };
201 self.lora_adapter_ids = Some(lora_adapter_ids);
202 self
203 }
204
205 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
206 self.hf_cache_path = Some(hf_cache_path);
207 self
208 }
209
210 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
213 let loader: Box<dyn NormalModelLoader> = match loader_tp {
214 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
215 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
216 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
217 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
218 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
219 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
220 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
221 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
222 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
223 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
224 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
225 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
226 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
227 Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
228 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
229 Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
230 None => Box::new(AutoNormalLoader),
231 };
232 Ok(Box::new(NormalLoader {
233 inner: loader,
234 model_id: self.model_id.unwrap(),
235 config: self.config,
236 xlora_model_id: self.xlora_model_id,
237 lora_adapter_ids: self.lora_adapter_ids,
238 kind: self.kind,
239 xlora_order: self.xlora_order,
240 no_kv_cache: self.no_kv_cache,
241 chat_template: self.chat_template,
242 tokenizer_json: self.tokenizer_json,
243 tgt_non_granular_index: self.tgt_non_granular_index,
244 jinja_explicit: self.jinja_explicit,
245 token_source: RwLock::new(None),
246 revision: RwLock::new(None),
247 from_uqff: RwLock::new(None),
248 hf_cache_path: self.hf_cache_path,
249 }))
250 }
251}
252
253impl Loader for NormalLoader {
254 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
255 fn load_model_from_hf(
256 &self,
257 revision: Option<String>,
258 token_source: TokenSource,
259 dtype: &dyn TryIntoDType,
260 device: &Device,
261 silent: bool,
262 mapper: DeviceMapSetting,
263 in_situ_quant: Option<IsqType>,
264 paged_attn_config: Option<PagedAttentionConfig>,
265 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
266 let cache = self
267 .hf_cache_path
268 .clone()
269 .map(Cache::new)
270 .unwrap_or_default();
271 GLOBAL_HF_CACHE.get_or_init(|| cache);
272
273 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
274 LocalModelPaths,
275 &token_source,
276 revision.clone(),
277 self,
278 None,
279 None,
280 silent,
281 self.config.from_uqff.is_some()
282 );
283 if let Some(from_uqff) = self.config.from_uqff.clone() {
284 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
285 }
286 *self
287 .token_source
288 .write()
289 .expect("Failed to write to token source") = Some(token_source);
290 *self.revision.write().expect("Failed to write to revision") = revision;
291 self.load_model_from_path(
292 &paths?,
293 dtype,
294 device,
295 silent,
296 mapper,
297 in_situ_quant,
298 paged_attn_config,
299 )
300 }
301
302 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
303 fn load_model_from_path(
304 &self,
305 paths: &Box<dyn ModelPaths>,
306 dtype: &dyn TryIntoDType,
307 device: &Device,
308 silent: bool,
309 mut mapper: DeviceMapSetting,
310 mut in_situ_quant: Option<IsqType>,
311 mut paged_attn_config: Option<PagedAttentionConfig>,
312 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
313 let config = std::fs::read_to_string(paths.get_config_filename())?;
314
315 if !self.inner.supports_paged_attention(&config)? {
316 paged_attn_config = None;
317 }
318
319 info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");
320
321 let use_nccl = mistralrs_quant::distributed::use_nccl();
322
323 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
324 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
325 let WorkerTransferData::Init { id: _, worker_rank } = payload;
326 vec![candle_core::Device::new_cuda(worker_rank + 1)?]
327 } else if use_nccl {
328 vec![candle_core::Device::new_cuda(0)?]
329 } else {
330 device_map::get_all_similar_devices(device)?
331 };
332 #[cfg(feature = "cuda")]
333 for device in &available_devices {
334 if let Device::Cuda(dev) = device {
335 unsafe { dev.disable_event_tracking() };
336 }
337 }
338 let device = if use_nccl || cfg!(feature = "ring") {
339 available_devices[0].clone()
340 } else {
341 device.clone()
342 };
343
344 if use_nccl || cfg!(feature = "ring") {
346 mapper = DeviceMapSetting::DummyNccl {
347 nm_device: available_devices[0].clone(),
348 };
349 } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
350 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
352
353 if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
355 in_situ_quant = None;
356 }
357
358 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
361 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
362 let weight_pack_factor = {
363 let ser_artifacts = unsafe {
364 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
365 };
366 let mut total_pack_factors = 0;
367 let total_tensors = ser_artifacts.tensors().len();
368 for (_, artifact) in ser_artifacts.tensors() {
369 let artifact = artifact.data();
370 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
372 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
373 {
374 QuantizedSerdeType::Hqq => {
375 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
376 .pack_factor(dtype)
377 }
378 QuantizedSerdeType::Gguf => {
379 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
380 .pack_factor(dtype)
381 }
382 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
383 QuantizedSerdeType::Unquant => 1,
384 QuantizedSerdeType::Afq => {
385 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
386 .pack_factor(dtype)
387 }
388 };
389 total_pack_factors += pack_factor;
390 }
391
392 total_pack_factors / total_tensors
393 };
394
395 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
396 &config,
397 dtype,
398 weight_pack_factor,
399 None,
400 )?;
401 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
402 &config,
403 dtype,
404 weight_pack_factor,
405 None,
406 )?;
407 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
408 (
409 layer_sizes_in_bytes,
410 non_mapped_size_in_bytes,
411 layer_sizes_sum + non_mapped_size_in_bytes,
412 )
413 } else if let Some(isq) = in_situ_quant {
414 let weight_pack_factor = isq.pack_factor(dtype);
415 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
416 &config,
417 dtype,
418 weight_pack_factor,
419 None,
420 )?;
421 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
422 &config,
423 dtype,
424 weight_pack_factor,
425 None,
426 )?;
427 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
428 (
429 layer_sizes_in_bytes,
430 non_mapped_size_in_bytes,
431 layer_sizes_sum + non_mapped_size_in_bytes,
432 )
433 } else {
434 let weight_pack_factor =
436 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
437 let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
438 &config,
439 dtype,
440 weight_pack_factor,
441 None,
442 )?;
443 let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
444 &config,
445 dtype,
446 weight_pack_factor,
447 None,
448 )?;
449 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
450 (
451 layer_sizes_in_bytes,
452 non_mapped_size_in_bytes,
453 layer_sizes_sum + non_mapped_size_in_bytes,
454 )
455 };
456
457 let new = auto_device_map::get_device_layers(
458 &*self.inner,
459 &config,
460 self.inner.num_layers(&config)?,
461 layer_sizes_in_bytes,
462 non_mapped_size_in_bytes,
463 total_model_size_in_bytes,
464 &available_devices,
465 dtype,
466 ¶ms,
467 paged_attn_config.as_ref(),
468 )?;
469 mapper = DeviceMapSetting::Map(new);
470 }
471
472 let pipeline_mapper = mapper.into_mapper(
473 self.inner.num_layers(&config)?,
474 &device,
475 self.config.topology.as_ref(),
476 )?;
477 let mapper = mapper.into_mapper(
478 self.inner.num_layers(&config)?,
479 &device,
480 self.config.topology.as_ref(),
481 )?;
482 let mut layer_devices = Vec::new();
483 for layer in 0..self.inner.num_layers(&config)? {
484 let device = mapper.device_for(layer, false).cloned();
485 layer_devices.push(device);
486 }
487 let dtype = mapper.get_min_dtype(dtype)?;
488
489 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
492 if mapping_uses_cpu && paged_attn_config.is_some() {
493 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
494 paged_attn_config = None;
495 }
496
497 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
498 if crate::using_flash_attn() {
499 once_log_info("FlashAttention is enabled.");
500 }
501
502 let topology_overrides = self
503 .config
504 .topology
505 .as_ref()
506 .map(|topology| {
507 topology
508 .pattern_overrides()
509 .into_iter()
510 .map(|(regex, layer)| ImmediateIsqOverride {
511 predicate: regex,
512 ty: layer.isq,
513 device: layer.device.clone(),
514 })
515 .collect::<Vec<_>>()
516 })
517 .unwrap_or_default();
518 let has_override_isq = topology_overrides
519 .iter()
520 .any(|override_entry| override_entry.ty.is_some());
521 let topology_requires_post_quant = self
522 .config
523 .topology
524 .as_ref()
525 .is_some_and(|topology| topology.requires_post_quantization());
526
527 let allow_immediate_cli = self.config.imatrix.is_none()
528 && self.config.calibration_file.is_none()
529 && !device.is_cuda()
530 && in_situ_quant.is_some();
531
532 let mut immediate_ty = None;
533 let mut immediate_predicates = Vec::new();
534 if allow_immediate_cli {
535 immediate_ty = in_situ_quant;
536 immediate_predicates =
537 if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
538 self.inner.immediate_isq_predicates_moqe(&config)?
539 } else {
540 self.inner.immediate_isq_predicates(&config)?
541 };
542 info!("Applying ISQ to {in_situ_quant:?}");
543 if immediate_predicates.is_empty() {
544 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
545 }
546 }
547
548 let use_immediate = allow_immediate_cli || has_override_isq;
549 if use_immediate {
550 mistralrs_quant::set_immediate_isq_with_overrides(
551 immediate_ty,
552 immediate_predicates.clone(),
553 topology_overrides.clone(),
554 );
555 }
556
557 let mut loading_isq = if use_immediate {
559 false
560 } else {
561 in_situ_quant.is_some()
562 };
563 if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
564 loading_isq = true;
565 }
566 loading_isq |= topology_requires_post_quant;
567
568 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
569 anyhow::bail!(
570 "`imatrix` and `calibration_file` were both specified, this is not allowed."
571 );
572 }
573
574 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
576 loading_isq = false;
577 device.clone()
578 } else {
579 Device::Cpu
580 };
581
582 let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
583
584 let attention_mechanism = if paged_attn_config.is_some() {
585 AttentionImplementation::PagedAttention
586 } else {
587 AttentionImplementation::Eager
588 };
589
590 let multi_progress = Arc::new(MultiProgress::new());
591
592 let matformer_slicing_config = if let Some(matformer_path) =
594 &self.config.matformer_config_path
595 {
596 use crate::matformer::{MatformerConfig, MatformerSliceConfig};
597 info!("Loading Matformer config from {:?}", matformer_path);
598 let config = Arc::new(MatformerConfig::from_file(matformer_path)?);
599
600 if let Some(slice_name) = &self.config.matformer_slice_name {
601 info!("Using Matformer slice: {}", slice_name);
602 Some(MatformerSliceConfig::new(slice_name.clone(), config))
603 } else {
604 warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
607 None
608 }
609 } else {
610 None
611 };
612
613 let mut model = if use_nccl || cfg!(feature = "ring") {
614 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
615 dtype,
616 &device,
617 &available_devices,
618 silent,
619 &config,
620 loading_isq,
621 self.config.from_uqff.is_some(),
622 self.config.organization,
623 &*self.inner,
624 paths.as_ref(),
625 )?;
626
627 match self.kind {
629 ModelKind::Normal => normal_model_loader_sharded!(
630 sharded_vb,
631 config,
632 self.inner,
633 mapper,
634 loading_isq,
635 device.clone(),
636 attention_mechanism,
637 multi_progress.clone(),
638 matformer_slicing_config.clone(),
639 ),
640 ModelKind::Adapter {
641 adapter: AdapterKind::XLora,
642 } => xlora_model_loader!(
643 paths,
644 Some(dtype),
645 &load_device,
646 layer_devices.clone(),
647 config,
648 self.inner,
649 silent,
650 mapper,
651 loading_isq,
652 device.clone(),
653 multi_progress.clone(),
654 matformer_slicing_config.clone(),
655 ),
656 ModelKind::Adapter {
657 adapter: AdapterKind::Lora,
658 } => lora_model_loader!(
659 paths,
660 Some(dtype),
661 &load_device,
662 layer_devices.clone(),
663 config,
664 self.inner,
665 silent,
666 mapper,
667 loading_isq,
668 self.config.from_uqff.is_some(),
669 device.clone(),
670 attention_mechanism,
671 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
672 multi_progress.clone(),
673 matformer_slicing_config.clone(),
674 ),
675 _ => unreachable!(),
676 }
677 } else {
678 match self.kind {
679 ModelKind::Normal => normal_model_loader!(
680 paths,
681 Some(dtype),
682 &load_device,
683 layer_devices.clone(),
684 config,
685 self.inner,
686 silent,
687 mapper,
688 loading_isq,
689 self.config.from_uqff.is_some(),
690 device.clone(),
691 attention_mechanism,
692 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
693 multi_progress.clone(),
694 matformer_slicing_config.clone(),
695 ),
696 ModelKind::Adapter {
697 adapter: AdapterKind::XLora,
698 } => xlora_model_loader!(
699 paths,
700 Some(dtype),
701 &load_device,
702 layer_devices.clone(),
703 config,
704 self.inner,
705 silent,
706 mapper,
707 loading_isq,
708 device.clone(),
709 multi_progress.clone(),
710 matformer_slicing_config.clone(),
711 ),
712 ModelKind::Adapter {
713 adapter: AdapterKind::Lora,
714 } => lora_model_loader!(
715 paths,
716 Some(dtype),
717 &load_device,
718 layer_devices.clone(),
719 config,
720 self.inner,
721 silent,
722 mapper,
723 loading_isq,
724 self.config.from_uqff.is_some(),
725 device.clone(),
726 attention_mechanism,
727 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
728 multi_progress.clone(),
729 matformer_slicing_config.clone(),
730 ),
731 _ => unreachable!(),
732 }
733 };
734
735 let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
736 let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
737 match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
738 Ok(conf) => Some(conf),
739 Err(e) => {
740 warn!("Failed to parse generation_config.json: {}", e);
741 None
742 }
743 }
744 });
745
746 let chat_template_explicit = paths
747 .get_chat_template_explicit()
748 .as_ref()
749 .map(|x| x.to_string_lossy().to_string());
750 let chat_template = get_chat_template(
751 paths,
752 self.jinja_explicit.as_ref(),
753 chat_template_explicit.as_ref(),
754 self.chat_template.as_ref(),
755 None,
756 );
757
758 if let Some(calibration_file) = &self.config.calibration_file {
759 let calibration_data = std::fs::read_to_string(calibration_file)?;
760 let tokens = tokenizer
762 .encode_fast(calibration_data, false)
763 .map_err(anyhow::Error::msg)?
764 .get_ids()
765 .to_vec();
766 info!(
767 "Collecting imatrix from calibration file `{}` of {} tokens.",
768 calibration_file.display(),
769 tokens.len()
770 );
771 let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
772 let bos_tok_id = tokenizer
773 .token_to_id(&bos_toks[0])
774 .expect("Somehow the bos token is not present.");
775
776 match self.config.organization {
777 IsqOrganization::Default => model.begin_track_stats()?,
778 IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
779 }
780
781 const CHUNK_SIZE: usize = 1024;
782 let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
783 let start = Instant::now();
784 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
785 let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
786 let chunk_len = chunk.len();
787
788 let start = Instant::now();
789 let inputs = make_prompt_chunk(
790 0,
791 vec![&chunk],
792 &[0],
793 &load_device,
794 None,
795 false,
796 None,
797 Some(pipeline_mapper.as_ref()),
798 )?;
799
800 model.forward(
801 &inputs.input.to_device(model.device())?,
802 &inputs.positions,
803 inputs.context_lens.clone(),
804 inputs.position_ids.clone(),
805 None,
806 &inputs.flash_meta.clone(),
807 )?;
808
809 match model.cache_mut() {
810 EitherCache::Full(full) => {
811 for layer in &mut *full.lock() {
812 *layer = None
813 }
814 }
815 EitherCache::Normal(normal) => {
816 for layer in &mut *normal.lock().unwrap().0 {
817 layer.reset();
818 }
819 }
820 }
821
822 let end = Instant::now();
823 info!(
824 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
825 i + 1,
826 end.duration_since(start).as_secs_f32()
827 );
828 }
829 load_device.synchronize()?;
830 let end = Instant::now();
831 info!(
832 "Finished collecting imatrix in {:.2}s",
833 end.duration_since(start).as_secs_f32()
834 );
835 }
836
837 let should_serialize = self.config.write_uqff.is_some();
839 let should_quantize_pass = loading_isq;
840
841 if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
842 let imatrix_source = if should_quantize_pass {
843 match (
844 self.config.imatrix.as_ref(),
845 self.config.calibration_file.is_some(),
846 ) {
847 (None, false) => None,
848 (Some(file), false) => Some(ImatrixDataSource::File(file)),
849 (None, true) => Some(ImatrixDataSource::Collected),
850 (Some(_), true) => unreachable!(),
851 }
852 } else {
853 None
854 };
855
856 if should_quantize_pass {
857 info!("Applying ISQ to all ranks.");
858 } else {
859 info!("Serializing existing ISQ tensors without additional quantization.");
860 }
861
862 let multi_progress = Arc::new(MultiProgress::new());
863
864 model.quantize(
865 in_situ_quant,
866 model.device().clone(),
867 self.config.topology.as_ref(),
868 silent,
869 imatrix_source,
870 self.config.organization,
871 should_quantize_pass,
872 self.config.write_uqff.as_ref(),
873 UqffFullSer {
874 tokenizer: &tokenizer,
875 template_filename: paths.get_template_filename(),
876 generation_config: paths.get_gen_conf_filename(),
877 config: config.clone(),
878 processor_filename: &None,
879 preprocessor_filename: &None,
880 modules: None,
881 module_paths: None,
882 },
883 multi_progress.clone(),
884 )?;
885 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
886 model.load_from_artifacts(
887 device.clone(),
888 self.config.topology.as_ref(),
889 silent,
890 from_uqff,
891 )?;
892 }
893
894 let paged_attn_config = if matches!(
895 self.kind,
896 ModelKind::Adapter {
897 adapter: AdapterKind::XLora
898 }
899 ) {
900 warn!(
901 "Adapter parallel_models do not currently support PagedAttention, running without"
902 );
903 None
904 } else {
905 paged_attn_config
906 };
907
908 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
909 let cache_config = calculate_cache_config(
910 paged_attn_config.mem_gpu,
911 paged_attn_config.block_size,
912 dtype,
913 paged_attn_config.cache_type,
914 model.config(),
915 &device,
916 &pipeline_mapper
917 .get_unique_devices()
918 .into_iter()
919 .map(Some)
920 .collect::<Vec<_>>(),
921 silent,
922 )?;
923
924 let mut layer_devices = Vec::new();
925 for layer in 0..self.inner.num_layers(&config)? {
926 let device = model.get_layers().1.device_for(layer, false).cloned();
927 layer_devices.push(device);
928 }
929 let cache_engine = CacheEngine::new(
930 model.config(),
931 &cache_config,
932 dtype,
933 model.device(),
934 layer_devices.clone(),
935 )?;
936
937 (Some(cache_config), Some(cache_engine))
938 } else {
939 (None, None)
940 };
941
942 let max_seq_len = model.max_seq_len();
943 let llg_factory = build_llg_factory(tokenizer.clone())?;
944 let num_hidden_layers = match model.cache() {
945 EitherCache::Full(full) => full.lock().len(),
946 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
947 };
948 let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
949 let sliding_window = model.config().sliding_window;
950 let model_metadata = Arc::new(model.config().clone());
951
952 Ok(Arc::new(Mutex::new(NormalPipeline {
953 model,
954 tokenizer: tokenizer.into(),
955 no_kv_cache: self.no_kv_cache,
956 chat_template: Arc::new(chat_template),
957 non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
958 NonGranularState {
959 non_granular_index: Arc::new(Mutex::new(0)),
960 tgt_non_granular_index,
961 }
962 }),
963 model_id: self.model_id.clone(),
964 metadata: Arc::new(GeneralMetadata {
965 max_seq_len,
966 llg_factory: Some(llg_factory),
967 no_kv_cache: self.no_kv_cache,
968 no_prefix_cache: is_xlora,
969 num_hidden_layers,
970 eos_tok: eos,
971 kind: self.kind.clone(),
972 is_xlora,
973 activation_dtype: dtype,
974 sliding_window,
975 cache_config,
976 cache_engine,
977 model_metadata: Some(model_metadata),
978 modalities: Modalities {
979 input: vec![SupportedModality::Text],
980 output: vec![SupportedModality::Text],
981 },
982 }),
983 topology: self.config.topology.clone(),
984 silent,
985 organization: self.config.organization,
986 template_filename: paths.get_template_filename().clone(),
987 generation_config: paths.get_gen_conf_filename().cloned(),
988 config,
989 imatrix: self.config.imatrix.clone(),
990 mapper: pipeline_mapper,
991 })))
992 }
993
994 fn get_id(&self) -> String {
995 self.model_id.clone()
996 }
997
998 fn get_kind(&self) -> ModelKind {
999 self.kind.clone()
1000 }
1001}
1002
1003impl PreProcessingMixin for NormalPipeline {
1004 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
1005 Some(self.chat_template.clone())
1006 }
1007 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
1008 None
1009 }
1010}
1011
1012impl IsqPipelineMixin for NormalPipeline {
1013 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
1014 let device = self.device().clone();
1015 let multi_progress = Arc::new(MultiProgress::new());
1016 self.model.quantize(
1017 Some(dtype),
1018 device.clone(),
1019 self.topology.as_ref(),
1020 self.silent,
1021 self.imatrix.as_ref().map(ImatrixDataSource::File),
1022 self.organization,
1023 true,
1024 None,
1025 UqffFullSer {
1026 tokenizer: &self.tokenizer,
1027 template_filename: &self.template_filename,
1028 generation_config: self.generation_config.as_ref(),
1029 config: self.config.clone(),
1030 processor_filename: &None,
1031 preprocessor_filename: &None,
1032 modules: None,
1033 module_paths: None,
1034 },
1035 multi_progress.clone(),
1036 )?;
1037 Ok(())
1038 }
1039}
1040
1041impl CacheManagerMixin for NormalPipeline {
1042 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
1043 if matches!(self.model.cache(), EitherCache::Full(_)) {
1044 FullCacheManager.clone_in_cache(self, seqs, false)
1045 } else {
1046 NormalCacheManager.clone_in_cache(self, seqs, false)
1047 }
1048 }
1049 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1050 if matches!(self.model.cache(), EitherCache::Full(_)) {
1051 FullCacheManager.clone_out_cache(self, seqs, false)
1052 } else {
1053 NormalCacheManager.clone_out_cache(self, seqs, false)
1054 }
1055 }
1056 fn set_none_cache(
1057 &self,
1058 seqs: &mut [&mut Sequence],
1059 reset_non_granular: bool,
1060 modify_draft_cache: bool,
1061 load_preallocated_cache: bool,
1062 ) {
1063 if matches!(self.model.cache(), EitherCache::Full(_)) {
1064 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
1065 } else {
1066 NormalCacheManager.set_none_cache(
1067 self,
1068 seqs,
1069 modify_draft_cache,
1070 load_preallocated_cache,
1071 );
1072 }
1073 if reset_non_granular {
1074 self.reset_non_granular_state()
1075 }
1076 }
1077 fn cache(&self) -> &EitherCache {
1078 self.model.cache()
1079 }
1080}
1081
1082impl MetadataMixin for NormalPipeline {
1083 fn device(&self) -> Device {
1084 self.model.device().clone()
1085 }
1086 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1087 Some(self.tokenizer.clone())
1088 }
1089 fn name(&self) -> String {
1090 self.model_id.clone()
1091 }
1092 fn reset_non_granular_state(&self) {
1093 if let Some(s) = self.non_granular_state.as_ref() {
1094 *self.cache().full().get_scalings_cache() = None;
1095 *get_mut_arcmutex!(s.non_granular_index) = 0;
1096 }
1097 }
1098 fn get_metadata(&self) -> Arc<GeneralMetadata> {
1099 self.metadata.clone()
1100 }
1101 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1102 Some(&*self.mapper)
1103 }
1104}
1105
1106#[async_trait::async_trait]
1107impl Pipeline for NormalPipeline {
1108 fn forward_inputs(
1109 &mut self,
1110 inputs: Box<dyn Any>,
1111 return_raw_logits: bool,
1112 ) -> Result<ForwardInputsResult, candle_core::Error> {
1113 let ModelInputs {
1114 input_ids,
1115 input_ids_full,
1116 seqlen_offsets,
1117 seqlen_offsets_full,
1118 context_lens,
1119 position_ids,
1120 paged_attn_meta,
1121 flash_meta,
1122 flash_meta_full,
1123 } = *inputs.downcast().expect("Downcast failed.");
1124 let metadata = self.get_metadata();
1125 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1126 (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1127 (Some(_), None) => {
1128 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1130 }
1131 (None, Some(_)) => {
1132 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1134 }
1135 (None, None) => None,
1136 };
1137 let logits = match self.model.is_xlora() {
1138 false => {
1139 let paged_attn_meta = paged_attn_meta
1140 .as_ref()
1141 .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1142
1143 self.model.forward(
1144 &input_ids,
1145 &seqlen_offsets,
1146 context_lens,
1147 position_ids,
1148 paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1149 &flash_meta,
1150 )?
1151 }
1152 true => self.model.xlora_forward(
1153 &input_ids,
1154 input_ids_full.as_ref().unwrap_or(&input_ids),
1155 &seqlen_offsets,
1156 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1157 self.no_kv_cache,
1158 &self.non_granular_state,
1159 context_lens,
1160 position_ids,
1161 &flash_meta,
1162 flash_meta_full.as_ref().unwrap_or(&flash_meta),
1163 )?,
1164 };
1165 if return_raw_logits {
1166 Ok(ForwardInputsResult::RawLogits { logits })
1167 } else {
1168 Ok(ForwardInputsResult::CausalGeneration { logits })
1169 }
1170 }
1171 async fn sample_causal_gen(
1172 &self,
1173 seqs: &mut [&mut Sequence],
1174 logits: Vec<Tensor>,
1175 prefix_cacher: &mut PrefixCacheManagerV2,
1176 disable_eos_stop: bool,
1177 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1178 ) -> Result<(), candle_core::Error> {
1179 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1180 }
1181 fn category(&self) -> ModelCategory {
1182 ModelCategory::Text
1183 }
1184}
1185
1186impl AnyMoePipelineMixin for NormalPipeline {
1187 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1188 self.model.finish_training(gate_model_id)
1189 }
1190 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1191 self.model.get_vars()
1192 }
1193 fn amoe_base_model_trainable_params(&self) -> usize {
1194 self.model.trainable_params()
1195 }
1196 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1197 self.model.take_cached_gating_outputs()
1198 }
1199 fn amoe_create_layers(
1200 &mut self,
1201 model_ids: Vec<String>,
1202 token: &TokenSource,
1203 revision: Option<String>,
1204 match_regex: &str,
1205 config: crate::amoe::AnyMoeConfig,
1206 dtype: candle_core::DType,
1207 dev: &Device,
1208 (prefix, mlp): (String, String),
1209 layers: Vec<usize>,
1210 expert_type: AnyMoeExpertType,
1211 silent: bool,
1212 gate_model_id: Option<String>,
1213 ) -> candle_core::Result<()> {
1214 let mut vbs = Vec::new();
1215 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1217 for model_id in model_ids {
1218 let model_id_str = &model_id;
1219 let model_id = Path::new(&model_id);
1220
1221 let api = {
1222 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1223 let mut api = ApiBuilder::from_cache(cache)
1224 .with_progress(!silent)
1225 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1226 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1227 api = api.with_cache_dir(x.into());
1228 }
1229 api.build().map_err(candle_core::Error::msg)?
1230 };
1231 let revision = revision.clone().unwrap_or("main".to_string());
1232 let api = api.repo(Repo::with_revision(
1233 model_id_str.clone(),
1234 RepoType::Model,
1235 revision.clone(),
1236 ));
1237
1238 let mut filenames = vec![];
1239 for rfilename in
1240 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1241 {
1242 filenames.push(api_get_file!(api, &rfilename, model_id));
1243 }
1244
1245 let regex = regex.clone();
1246 let match_regex_clone = match_regex.to_string();
1247 let layers_clone = layers.clone();
1248 let vb = from_mmaped_safetensors(
1249 filenames,
1250 vec![],
1251 Some(dtype),
1252 dev,
1253 vec![None],
1254 silent,
1255 None,
1256 move |key| {
1257 if regex.is_match(&key) {
1258 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1261 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1262 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1263 .parse::<usize>()
1264 .unwrap();
1265 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1266 } else {
1267 false
1268 }
1269 },
1270 Arc::new(|_| DeviceForLoadTensor::Base),
1271 )?;
1272 vbs.push(vb);
1273 }
1274
1275 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1276 let model_id_str = &gate_model_id;
1277 let model_id = Path::new(&gate_model_id);
1278
1279 let api = {
1280 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1281 let mut api = ApiBuilder::from_cache(cache)
1282 .with_progress(!silent)
1283 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1284 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1285 api = api.with_cache_dir(x.into());
1286 }
1287 api.build().map_err(candle_core::Error::msg)?
1288 };
1289 let revision = revision.clone().unwrap_or("main".to_string());
1290 let api = api.repo(Repo::with_revision(
1291 model_id_str.clone(),
1292 RepoType::Model,
1293 revision.clone(),
1294 ));
1295
1296 let mut gate_filenames = vec![];
1297 for rfilename in
1298 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1299 {
1300 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1301 }
1302 assert_eq!(
1303 gate_filenames.len(),
1304 1,
1305 "Gate model ID must contain only one .safetensors file"
1306 );
1307
1308 let vb = from_mmaped_safetensors(
1309 gate_filenames.clone(),
1310 vec![],
1311 Some(dtype),
1312 dev,
1313 vec![None],
1314 silent,
1315 None,
1316 |_| true,
1317 Arc::new(|_| DeviceForLoadTensor::Base),
1318 )?;
1319 info!(
1320 "Loaded gating layers from `{}`",
1321 gate_filenames[0].display()
1322 );
1323 Some(vb)
1324 } else {
1325 None
1326 };
1327
1328 self.model.create_anymoe_layers(
1329 vbs.clone(),
1330 config.clone(),
1331 (prefix.clone(), mlp.clone()),
1332 layers.clone(),
1333 expert_type.clone(),
1334 gate_vb.clone(),
1335 )?;
1336
1337 Ok(())
1338 }
1339 fn amoe_supported(&self) -> bool {
1340 self.model.amoe_supported()
1341 }
1342}