use super::isq::ImatrixDataSource;
use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use super::{
    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
    GraniteMoeHybridLoader, LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType,
    Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader,
    SmolLm3Loader, Starcoder2Loader,
};
use crate::amoe::AnyMoeExpertType;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::isq::UqffFullSer;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::xlora_models::NonGranularState;
use crate::{
    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

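/// A fully loaded "normal" (plain causal text) model pipeline: the model itself
/// plus everything needed at inference time (tokenizer, chat template, KV-cache
/// bookkeeping, and the device mapper used to place layers).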
pub struct NormalPipeline {
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    imatrix: Option<PathBuf>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

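/// A [`Loader`] for normal models. Wraps an architecture-specific
/// [`NormalModelLoader`] and records per-load state (token source, revision,
/// resolved UQFF paths) behind `RwLock`s so `load_model_from_hf` can stash them
/// before delegating to `load_model_from_path`.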
pub struct NormalLoader {
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

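/// Builder for a [`NormalLoader`]. A minimal usage sketch; the model ID and
/// argument values below are illustrative, not prescriptive:
///
/// ```ignore
/// let loader = NormalLoaderBuilder::new(
///     NormalSpecificConfig::default(),
///     None,  // chat template override
///     None,  // tokenizer.json override
///     Some("mistralai/Mistral-7B-Instruct-v0.3".to_string()),
///     false, // no_kv_cache
///     None,  // explicit Jinja template
/// )
/// .build(None)?; // `None` auto-detects the architecture
/// ```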
#[derive(Default)]
pub struct NormalLoaderBuilder {
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

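/// Options specific to loading a normal model: device/ISQ topology, ISQ
/// organization, UQFF (de)serialization paths, imatrix or calibration data,
/// the HF cache location, and Matformer slicing.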
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    pub topology: Option<Topology>,
    pub organization: IsqOrganization,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl NormalLoaderBuilder {
    pub fn new(
        config: NormalSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::XLora,
        };
        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
        let loader: Box<dyn NormalModelLoader> = match loader_tp {
            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
            Some(NormalLoaderType::GraniteMoeHybrid) => Box::new(GraniteMoeHybridLoader),
            None => Box::new(AutoNormalLoader),
        };
        Ok(Box::new(NormalLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            lora_adapter_ids: self.lora_adapter_ids,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
        }))
    }
}

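// `load_model_from_hf` resolves model files from the Hub (or the local cache),
// records the token source and revision, then delegates the heavy lifting to
// `load_model_from_path`.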
impl Loader for NormalLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl || cfg!(feature = "ring") {
            available_devices[0].clone()
        } else {
            device.clone()
        };

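        // With NCCL (or the `ring` feature) the device mapper is a placeholder:
        // layers are sharded across ranks rather than split across local devices.
        // Otherwise, an automatic device map is computed from the per-layer memory
        // estimates below.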
        if use_nccl || cfg!(feature = "ring") {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

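            // Estimate per-layer and non-mapped sizes so layers can be packed onto
            // the available devices. The weight "pack factor" (unquantized-to-stored
            // size ratio) comes from the UQFF artifacts when deserializing, from the
            // requested ISQ type, or from the model's own quantization config.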
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

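        // Choose between "immediate" ISQ (quantize tensors as they load) and a
        // post-load quantization pass. The immediate path handles the plain CLI
        // case: an ISQ type was requested, no imatrix/calibration data is in play,
        // and the main device is not CUDA. Topology overrides may also force
        // immediate quantization for matching tensors.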
        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates =
                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
                    self.inner.immediate_isq_predicates_moqe(&config)?
                } else {
                    self.inner.immediate_isq_predicates(&config)?
                };
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        let mut model = if use_nccl || cfg!(feature = "ring") {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
                Ok(conf) => Some(conf),
                Err(e) => {
                    warn!("Failed to parse generation_config.json: {}", e);
                    None
                }
            }
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

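        // If a calibration file was provided, run it through the model in chunks
        // while tracking activation statistics (an "imatrix"), resetting the KV
        // cache between chunks. The collected statistics guide the quantization
        // pass below.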
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                    EitherCache::Hybrid(hybrid) => {
                        hybrid.lock().unwrap().reset();
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

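        // Run the post-load quantization and/or UQFF serialization pass, unless the
        // model came from UQFF, in which case the artifacts are loaded back into the
        // model instead.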
        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };

            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }

            let multi_progress = Arc::new(new_multi_progress());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: None,
                    module_paths: None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!("Adapter models do not currently support PagedAttention; running without it.");
            None
        } else {
            paged_attn_config
        };

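        // With PagedAttention enabled, size the KV-cache block allocation from the
        // configured GPU memory budget and build a `CacheEngine` spanning the
        // per-layer devices chosen by the mapper.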
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for NormalPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for NormalPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        let multi_progress = Arc::new(new_multi_progress());
        self.model.quantize(
            Some(dtype),
            device.clone(),
            self.topology.as_ref(),
            self.silent,
            self.imatrix.as_ref().map(ImatrixDataSource::File),
            self.organization,
            true,
            None,
            UqffFullSer {
                tokenizer: &self.tokenizer,
                template_filename: &self.template_filename,
                generation_config: self.generation_config.as_ref(),
                config: self.config.clone(),
                processor_filename: &None,
                preprocessor_filename: &None,
                modules: None,
                module_paths: None,
            },
            multi_progress.clone(),
        )?;
        Ok(())
    }
}

impl CacheManagerMixin for NormalPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        match self.model.cache() {
            EitherCache::Full(_) => {
                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
            }
            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for NormalPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

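// Core inference entry point: `forward_inputs` downcasts the type-erased inputs
// back to `ModelInputs`, pairs PagedAttention metadata with the cache engine when
// both are present, and dispatches to either the plain or the X-LoRA forward pass.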
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model.is_xlora() {
            false => {
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

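// AnyMoE support: `amoe_create_layers` builds expert layers from additional
// safetensors checkpoints, loading only tensors whose keys match `match_regex`
// (and whose layer index is in `layers`, when non-empty), plus an optional
// single-file gating checkpoint.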
impl AnyMoePipelineMixin for NormalPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model.create_anymoe_layers(
            vbs.clone(),
            config.clone(),
            (prefix.clone(), mlp.clone()),
            layers.clone(),
            expert_type.clone(),
            gate_vb.clone(),
        )?;

        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}