use super::isq::ImatrixDataSource;
use super::llg::build_llg_factory;
use super::{
    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
    TokenSource,
};
use super::{
    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
};
use super::{
    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
    GptOssLoader, GraniteMoeHybridLoader, LlamaLoader, MistralLoader, MixtralLoader,
    NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
    Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
};
use crate::amoe::AnyMoeExpertType;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::{self, DeviceMapper};
use crate::distributed::{self, WorkerTransferData};
use crate::kv_cache::{FullCacheManager, HybridCacheManager, NormalCacheManager};
use crate::lora::Ordering;
use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::isq::UqffFullSer;
use crate::pipeline::loaders::auto_device_map;
use crate::pipeline::loaders::QuantizationConfigShim;
use crate::pipeline::sampling::sample_and_add_toks;
use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::xlora_models::NonGranularState;
use crate::{
    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
};
use anyhow::Result;
use candle_core::{Device, Tensor, Var};
use hf_hub::Cache;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::{
    AfqLayer, GgufMatMul, HqqLayer, ImmediateIsqOverride, IsqType, QuantizedSerdeType,
};
use rand_isaac::Isaac64Rng;
use regex_automata::meta::Regex;
use std::any::Any;
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use std::{env, fs};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::{info, warn};

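/// Pipeline for plain causal-LM ("normal") text models. Owns the loaded model,
/// tokenizer, chat template, and device mapper, plus the metadata shared with the
/// engine (KV-cache configuration, EOS tokens, modalities, and so on).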
pub struct NormalPipeline {
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    config: String,
    imatrix: Option<PathBuf>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

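/// [`Loader`] implementation for normal text models. Wraps a per-architecture
/// [`NormalModelLoader`] and drives path resolution, automatic device mapping,
/// in-situ quantization (ISQ), and construction of the final [`NormalPipeline`].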
pub struct NormalLoader {
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

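/// Builder for a [`NormalLoader`].
///
/// A minimal usage sketch (the model ID is a placeholder; marked `ignore` since
/// the surrounding crate setup is elided):
/// ```ignore
/// let loader = NormalLoaderBuilder::new(
///     NormalSpecificConfig::default(),
///     None, // chat_template
///     None, // tokenizer_json
///     Some("some-org/some-model".to_string()),
///     false, // no_kv_cache
///     None, // jinja_explicit
/// )
/// .build(None)?; // `None` auto-detects the architecture
/// ```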
#[derive(Default)]
pub struct NormalLoaderBuilder {
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}

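/// Options specific to normal-model loading: device topology, ISQ organization,
/// UQFF read/write paths, imatrix/calibration inputs, and Matformer slicing.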
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    pub topology: Option<Topology>,
    pub organization: IsqOrganization,
    pub write_uqff: Option<PathBuf>,
    pub from_uqff: Option<Vec<PathBuf>>,
    pub imatrix: Option<PathBuf>,
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}

impl NormalLoaderBuilder {
    pub fn new(
        config: NormalSpecificConfig,
        chat_template: Option<String>,
        tokenizer_json: Option<String>,
        model_id: Option<String>,
        no_kv_cache: bool,
        jinja_explicit: Option<String>,
    ) -> Self {
        Self {
            config,
            chat_template,
            tokenizer_json,
            model_id,
            kind: ModelKind::Normal,
            jinja_explicit,
            no_kv_cache,
            ..Default::default()
        }
    }

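    /// Shared plumbing for the adapter constructors: records the X-LoRA ordering
    /// and, when no base model ID was supplied, falls back to the ordering's
    /// `base_model_id`.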
    fn with_adapter(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.xlora_model_id = Some(xlora_model_id);
        self.xlora_order = Some(xlora_order);
        self.no_kv_cache = no_kv_cache;
        self.tgt_non_granular_index = tgt_non_granular_index;
        self.model_id = if let Some(id) = self.model_id {
            Some(id)
        } else {
            info!(
                "Using adapter base model ID: `{}`",
                self.xlora_order.as_ref().unwrap().base_model_id
            );
            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
        };
        self
    }

    pub fn with_xlora(
        mut self,
        xlora_model_id: String,
        xlora_order: Ordering,
        no_kv_cache: bool,
        tgt_non_granular_index: Option<usize>,
    ) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::XLora,
        };
        self.with_adapter(
            xlora_model_id,
            xlora_order,
            no_kv_cache,
            tgt_non_granular_index,
        )
    }

    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
        self.kind = ModelKind::Adapter {
            adapter: AdapterKind::Lora,
        };
        self.lora_adapter_ids = Some(lora_adapter_ids);
        self
    }

    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
        self.hf_cache_path = Some(hf_cache_path);
        self
    }

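    /// Finalize the builder into a [`Loader`]. An explicit `loader_tp` selects
    /// the architecture; `None` defers detection to [`AutoNormalLoader`].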
    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
        let loader: Box<dyn NormalModelLoader> = match loader_tp {
            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
            Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
            Some(NormalLoaderType::GraniteMoeHybrid) => Box::new(GraniteMoeHybridLoader),
            Some(NormalLoaderType::GptOss) => Box::new(GptOssLoader),
            None => Box::new(AutoNormalLoader),
        };
        Ok(Box::new(NormalLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            config: self.config,
            xlora_model_id: self.xlora_model_id,
            lora_adapter_ids: self.lora_adapter_ids,
            kind: self.kind,
            xlora_order: self.xlora_order,
            no_kv_cache: self.no_kv_cache,
            chat_template: self.chat_template,
            tokenizer_json: self.tokenizer_json,
            tgt_non_granular_index: self.tgt_non_granular_index,
            jinja_explicit: self.jinja_explicit,
            token_source: RwLock::new(None),
            revision: RwLock::new(None),
            from_uqff: RwLock::new(None),
            hf_cache_path: self.hf_cache_path,
        }))
    }
}

impl Loader for NormalLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

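        // Device discovery: daemon workers are pinned to CUDA device
        // `worker_rank + 1`, NCCL runs use CUDA device 0, and otherwise every
        // device similar to the requested one is a candidate for auto mapping.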
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl || cfg!(feature = "ring") {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        if use_nccl || cfg!(feature = "ring") {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

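            // Estimate per-layer and non-mapped sizes so the automatic device
            // mapper can split layers across devices. The effective weight pack
            // factor comes from UQFF artifacts (averaged across tensors), from
            // the requested ISQ type, or from the model's own quant config.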
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let topology_overrides = self
            .config
            .topology
            .as_ref()
            .map(|topology| {
                topology
                    .pattern_overrides()
                    .into_iter()
                    .map(|(regex, layer)| ImmediateIsqOverride {
                        predicate: regex,
                        ty: layer.isq,
                        device: layer.device.clone(),
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        let has_override_isq = topology_overrides
            .iter()
            .any(|override_entry| override_entry.ty.is_some());
        let topology_requires_post_quant = self
            .config
            .topology
            .as_ref()
            .is_some_and(|topology| topology.requires_post_quantization());

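        // "Immediate" ISQ quantizes tensors as they are loaded rather than in a
        // separate pass afterwards. It is only taken for a CLI-style ISQ request
        // with no imatrix/calibration data on a non-CUDA main device; topology
        // overrides that carry their own ISQ types also enable it.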
        let allow_immediate_cli = self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && in_situ_quant.is_some();

        let mut immediate_ty = None;
        let mut immediate_predicates = Vec::new();
        if allow_immediate_cli {
            immediate_ty = in_situ_quant;
            immediate_predicates =
                if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly) {
                    self.inner.immediate_isq_predicates_moqe(&config)?
                } else {
                    self.inner.immediate_isq_predicates(&config)?
                };
            info!("Applying ISQ to {in_situ_quant:?}");
            if immediate_predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
        }

        let use_immediate = allow_immediate_cli || has_override_isq;
        if use_immediate {
            mistralrs_quant::set_immediate_isq_with_overrides(
                immediate_ty,
                immediate_predicates.clone(),
                topology_overrides.clone(),
            );
        }

        let mut loading_isq = if use_immediate {
            false
        } else {
            in_situ_quant.is_some()
        };
        if self.config.imatrix.is_some() || self.config.calibration_file.is_some() {
            loading_isq = true;
        }
        loading_isq |= topology_requires_post_quant;

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(new_multi_progress());

        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        let mut model = if use_nccl || cfg!(feature = "ring") {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
                Ok(conf) => Some(conf),
                Err(e) => {
                    warn!("Failed to parse generation_config.json: {}", e);
                    None
                }
            }
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

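        // Calibration pass: stream the calibration text through the model in
        // fixed-size chunks to collect imatrix statistics, clearing the KV cache
        // between chunks so only the activation statistics accumulate.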
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                    EitherCache::Hybrid(hybrid) => {
                        hybrid.lock().unwrap().reset();
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

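        // Post-load quantization and/or UQFF serialization. When loading from
        // pre-quantized UQFF artifacts instead, pull the tensors from disk.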
        let should_serialize = self.config.write_uqff.is_some();
        let should_quantize_pass = loading_isq;

        if (should_quantize_pass || should_serialize) && self.config.from_uqff.is_none() {
            let imatrix_source = if should_quantize_pass {
                match (
                    self.config.imatrix.as_ref(),
                    self.config.calibration_file.is_some(),
                ) {
                    (None, false) => None,
                    (Some(file), false) => Some(ImatrixDataSource::File(file)),
                    (None, true) => Some(ImatrixDataSource::Collected),
                    (Some(_), true) => unreachable!(),
                }
            } else {
                None
            };

            if should_quantize_pass {
                info!("Applying ISQ to all ranks.");
            } else {
                info!("Serializing existing ISQ tensors without additional quantization.");
            }

            let multi_progress = Arc::new(new_multi_progress());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                should_quantize_pass,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                    modules: None,
                    module_paths: None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!("Adapter models do not currently support PagedAttention, running without.");
            None
        } else {
            paged_attn_config
        };

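        // With PagedAttention enabled, size the block-based KV cache from the
        // configured memory budget and build a cache engine over the mapped
        // devices.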
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
            EitherCache::Hybrid(hybrid) => hybrid.lock().unwrap().num_layers(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for NormalPipeline {
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for NormalPipeline {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        let multi_progress = Arc::new(new_multi_progress());
        self.model.quantize(
            Some(dtype),
            device.clone(),
            self.topology.as_ref(),
            self.silent,
            self.imatrix.as_ref().map(ImatrixDataSource::File),
            self.organization,
            true,
            None,
            UqffFullSer {
                tokenizer: &self.tokenizer,
                template_filename: &self.template_filename,
                generation_config: self.generation_config.as_ref(),
                config: self.config.clone(),
                processor_filename: &None,
                preprocessor_filename: &None,
                modules: None,
                module_paths: None,
            },
            multi_progress.clone(),
        )?;
        Ok(())
    }
}

impl CacheManagerMixin for NormalPipeline {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_in_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_in_cache(self, seqs, false),
        }
    }
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
        match self.model.cache() {
            EitherCache::Full(_) => FullCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Normal(_) => NormalCacheManager.clone_out_cache(self, seqs, false),
            EitherCache::Hybrid(_) => HybridCacheManager.clone_out_cache(self, seqs, false),
        }
    }
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    ) {
        match self.model.cache() {
            EitherCache::Full(_) => {
                FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false)
            }
            EitherCache::Normal(_) => NormalCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
            EitherCache::Hybrid(_) => HybridCacheManager.set_none_cache(
                self,
                seqs,
                modify_draft_cache,
                load_preallocated_cache,
            ),
        }
    }
    fn cache(&self) -> &EitherCache {
        self.model.cache()
    }
}

impl MetadataMixin for NormalPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}

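// Core `Pipeline` impl: `forward_inputs` pairs PagedAttention metadata with the
// cache engine when both are present and routes X-LoRA models through
// `xlora_forward`.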
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model.is_xlora() {
            false => {
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}

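// AnyMoE support: expert weights are fetched from separate safetensors repos,
// filtered by a regex that matches per-layer keys, and handed to the model to
// build the MoE layers (optionally with a single-file gating model).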
impl AnyMoePipelineMixin for NormalPipeline {
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                move |key| {
                    if regex.is_match(&key) {
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        self.model.create_anymoe_layers(
            vbs.clone(),
            config.clone(),
            (prefix.clone(), mlp.clone()),
            layers.clone(),
            expert_type.clone(),
            gate_vb.clone(),
        )?;

        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}