1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7 TokenSource,
8};
9use super::{
10 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
15 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
16 Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::num::{NonZero, NonZeroUsize};
55use std::path::{Path, PathBuf};
56use std::str::FromStr;
57use std::sync::{Arc, RwLock};
58use std::time::Instant;
59use std::{env, fs};
60use tokenizers::Tokenizer;
61use tokio::sync::Mutex;
62use tracing::{info, warn};
63
/// A fully-loaded pipeline for "normal" (plain text, non-vision) models,
/// produced by [`NormalLoader`]. Owns the model, tokenizer, chat template,
/// and everything needed for inference and re-quantization.
pub struct NormalPipeline {
    // The underlying causal-LM implementation.
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    // When true, the KV cache is bypassed (used by the X-LoRA forward path).
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    // X-LoRA non-granular scaling state; `None` for non-X-LoRA models.
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    topology: Option<Topology>,
    silent: bool,
    // Which weights ISQ targets (e.g. MoE experts only vs. default set).
    organization: IsqOrganization,
    // Retained so `re_isq_model` can serialize a full UQFF bundle later.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    // Raw contents of the model's `config.json`.
    config: String,
    // Optional imatrix file reused when re-applying ISQ.
    imatrix: Option<PathBuf>,
    // Device mapper used to place per-layer tensors during input processing.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
82
/// A [`Loader`] for normal (text) models; constructed via
/// [`NormalLoaderBuilder::build`].
pub struct NormalLoader {
    // Architecture-specific loader (Mistral, Llama, …) or the auto-detecting one.
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    // The following use `RwLock` for interior mutability: `load_model_from_hf`
    // only receives `&self` but must record these for later stages.
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    // Explicit Jinja chat-template override, if any.
    jinja_explicit: Option<String>,
    // Custom Hugging Face cache directory; falls back to the default cache.
    hf_cache_path: Option<PathBuf>,
}
102
/// Builder for [`NormalLoader`]. Collects model/adapter identifiers and
/// configuration before [`NormalLoaderBuilder::build`] selects the concrete
/// architecture loader.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    // May be `None` when an adapter ordering supplies the base model ID.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
119
/// Options specific to loading normal (text) models.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    // Prompt-chunk size for chunked prefill; defaults to
    // `DEFAULT_PROMPT_CHUNK_SIZE` when `None`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub topology: Option<Topology>,
    pub organization: IsqOrganization,
    // Serialize the quantized model to UQFF at this path after ISQ.
    pub write_uqff: Option<PathBuf>,
    // Load pre-quantized UQFF artifacts instead of quantizing in situ.
    pub from_uqff: Option<Vec<PathBuf>>,
    // Pre-computed imatrix file for ISQ (mutually exclusive with
    // `calibration_file`).
    pub imatrix: Option<PathBuf>,
    // Text file used to collect an imatrix at load time.
    pub calibration_file: Option<PathBuf>,
    pub hf_cache_path: Option<PathBuf>,
    pub matformer_config_path: Option<PathBuf>,
    pub matformer_slice_name: Option<String>,
}
134
135impl NormalLoaderBuilder {
136 pub fn new(
137 config: NormalSpecificConfig,
138 chat_template: Option<String>,
139 tokenizer_json: Option<String>,
140 model_id: Option<String>,
141 no_kv_cache: bool,
142 jinja_explicit: Option<String>,
143 ) -> Self {
144 Self {
145 config,
146 chat_template,
147 tokenizer_json,
148 model_id,
149 kind: ModelKind::Normal,
150 jinja_explicit,
151 no_kv_cache,
152 ..Default::default()
153 }
154 }
155
156 fn with_adapter(
157 mut self,
158 xlora_model_id: String,
159 xlora_order: Ordering,
160 no_kv_cache: bool,
161 tgt_non_granular_index: Option<usize>,
162 ) -> Self {
163 self.xlora_model_id = Some(xlora_model_id);
164 self.xlora_order = Some(xlora_order);
165 self.no_kv_cache = no_kv_cache;
166 self.tgt_non_granular_index = tgt_non_granular_index;
167 self.model_id = if let Some(id) = self.model_id {
168 Some(id)
169 } else {
170 info!(
171 "Using adapter base model ID: `{}`",
172 self.xlora_order.as_ref().unwrap().base_model_id
173 );
174 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
175 };
176 self
177 }
178
179 pub fn with_xlora(
180 mut self,
181 xlora_model_id: String,
182 xlora_order: Ordering,
183 no_kv_cache: bool,
184 tgt_non_granular_index: Option<usize>,
185 ) -> Self {
186 self.kind = ModelKind::Adapter {
187 adapter: AdapterKind::XLora,
188 };
189 self.with_adapter(
190 xlora_model_id,
191 xlora_order,
192 no_kv_cache,
193 tgt_non_granular_index,
194 )
195 }
196
197 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
198 self.kind = ModelKind::Adapter {
199 adapter: AdapterKind::Lora,
200 };
201 self.lora_adapter_ids = Some(lora_adapter_ids);
202 self
203 }
204
205 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
206 self.hf_cache_path = Some(hf_cache_path);
207 self
208 }
209
210 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
213 let loader: Box<dyn NormalModelLoader> = match loader_tp {
214 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
215 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
216 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
217 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
218 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
219 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
220 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
221 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
222 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
223 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
224 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
225 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
226 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
227 Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
228 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
229 Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
230 None => Box::new(AutoNormalLoader),
231 };
232 Ok(Box::new(NormalLoader {
233 inner: loader,
234 model_id: self.model_id.unwrap(),
235 config: self.config,
236 xlora_model_id: self.xlora_model_id,
237 lora_adapter_ids: self.lora_adapter_ids,
238 kind: self.kind,
239 xlora_order: self.xlora_order,
240 no_kv_cache: self.no_kv_cache,
241 chat_template: self.chat_template,
242 tokenizer_json: self.tokenizer_json,
243 tgt_non_granular_index: self.tgt_non_granular_index,
244 jinja_explicit: self.jinja_explicit,
245 token_source: RwLock::new(None),
246 revision: RwLock::new(None),
247 from_uqff: RwLock::new(None),
248 hf_cache_path: self.hf_cache_path,
249 }))
250 }
251}
252
impl Loader for NormalLoader {
    /// Resolve model files from the Hugging Face Hub (or local cache), record
    /// the token source / revision / UQFF paths for later stages, then
    /// delegate to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Install the (possibly custom) HF cache directory process-wide,
        // exactly once; subsequent calls keep the first value.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // Resolve UQFF artifact paths up front so `load_model_from_path` can
        // read them through `self.from_uqff`.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        // Record token source and revision via interior mutability; the
        // `Loader` trait only provides `&self` here.
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load and assemble the full [`NormalPipeline`] from already-resolved
    /// paths. This performs, in order: device-map planning, model
    /// construction (possibly sharded over NCCL), optional calibration /
    /// imatrix collection, ISQ or UQFF artifact loading, and PagedAttention
    /// cache setup.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        // Architectures without PagedAttention support fall back to eager
        // attention.
        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.",);

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        // Daemon workers (spawned for tensor-parallel runs) are pinned to
        // their own CUDA ordinal; otherwise enumerate devices similar to the
        // requested one for automatic mapping.
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl || cfg!(feature = "ring") {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // NCCL/ring runs use a dummy mapper (sharding handles placement);
        // otherwise an `Auto` setting triggers memory-based layer planning.
        if use_nccl || cfg!(feature = "ring") {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // A checkpoint that is already quantized cannot also be ISQ'd.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // Estimate per-layer and non-mapped sizes so layers can be
            // distributed across devices. Three cases: UQFF artifacts
            // (average the real pack factors found in the file), ISQ (use the
            // target type's pack factor), or the checkpoint's own quant
            // config.
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        // Mean pack factor across all serialized tensors.
                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        // Two mappers are built: one kept for the pipeline's input processing
        // and one consumed by model construction below.
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        // The effective activation dtype is the minimum supported across all
        // mapped devices.
        let dtype = mapper.get_min_dtype(dtype)?;

        // PagedAttention has no CPU support; a mixed GPU/CPU mapping disables it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // Decide between "immediate" ISQ (quantize tensors as they load, via
        // per-tensor predicates) and deferred ISQ (quantize after loading).
        // Immediate ISQ is only taken when no imatrix/calibration/UQFF-write
        // work needs the unquantized weights first. NOTE(review): the
        // `!device.is_cuda()` condition restricts immediate ISQ to non-CUDA
        // devices — presumably deliberate; confirm against upstream intent.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
            {
                self.inner.immediate_isq_predicates_moqe(&config)?
            } else {
                self.inner.immediate_isq_predicates(&config)?
            };
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            false
        } else {
            in_situ_quant.is_some()
        };

        // A topology can request ISQ for individual layers even without a
        // global ISQ setting.
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Deferred ISQ stages weights on CPU first; everything else loads
        // directly onto the target device.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        // Optional Matformer slicing: load the config and pick a named slice,
        // or fall back to the model's default slice.
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // Build the model. NCCL/ring runs shard the varbuilder across ranks;
        // otherwise load per the device map. Each arm dispatches on the model
        // kind (normal / X-LoRA / LoRA).
        let mut model = if use_nccl || cfg!(feature = "ring") {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        // `generation_config.json` is best-effort: a parse failure only warns.
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
                Ok(conf) => Some(conf),
                Err(e) => {
                    warn!("Failed to parse generation_config.json: {}", e);
                    None
                }
            }
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        // Optional calibration pass: run the calibration text through the
        // model in fixed-size chunks to collect imatrix statistics before ISQ.
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                // Each chunk is prefixed with BOS.
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                // Drop KV-cache contents between chunks; only activation
                // statistics matter here.
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Apply deferred ISQ (and optionally serialize UQFF), unless the
        // weights are being loaded from pre-quantized UQFF artifacts instead.
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                // Both set is rejected by the bail! above.
                (Some(_), true) => unreachable!(),
            };

            info!("Applying ISQ to all ranks.");

            let multi_progress = Arc::new(MultiProgress::new());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        // X-LoRA does not support PagedAttention; drop the config with a warning.
        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!(
                "Adapter parallel_models do not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        // Set up the PagedAttention KV-cache engine sized to the configured
        // GPU/CPU memory budgets.
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            // Re-read per-layer devices from the *loaded* model, which may
            // differ from the pre-load mapping.
            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        // Assemble the final pipeline and its shared metadata.
        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(model_metadata),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    /// The model ID this loader was built for.
    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    /// The model kind (normal, LoRA, X-LoRA).
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
956
957impl PreProcessingMixin for NormalPipeline {
958 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
959 Some(self.chat_template.clone())
960 }
961 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
962 None
963 }
964}
965
966impl IsqPipelineMixin for NormalPipeline {
967 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
968 let device = self.device().clone();
969 let multi_progress = Arc::new(MultiProgress::new());
970 self.model.quantize(
971 Some(dtype),
972 device.clone(),
973 self.topology.as_ref(),
974 self.silent,
975 self.imatrix.as_ref().map(ImatrixDataSource::File),
976 self.organization,
977 None,
978 UqffFullSer {
979 tokenizer: &self.tokenizer,
980 template_filename: &self.template_filename,
981 generation_config: self.generation_config.as_ref(),
982 config: self.config.clone(),
983 processor_filename: &None,
984 preprocessor_filename: &None,
985 },
986 multi_progress.clone(),
987 )?;
988 Ok(())
989 }
990}
991
992impl CacheManagerMixin for NormalPipeline {
993 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
994 if matches!(self.model.cache(), EitherCache::Full(_)) {
995 FullCacheManager.clone_in_cache(self, seqs, false)
996 } else {
997 NormalCacheManager.clone_in_cache(self, seqs, false)
998 }
999 }
1000 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
1001 if matches!(self.model.cache(), EitherCache::Full(_)) {
1002 FullCacheManager.clone_out_cache(self, seqs, false)
1003 } else {
1004 NormalCacheManager.clone_out_cache(self, seqs, false)
1005 }
1006 }
1007 fn set_none_cache(
1008 &self,
1009 seqs: &mut [&mut Sequence],
1010 reset_non_granular: bool,
1011 modify_draft_cache: bool,
1012 load_preallocated_cache: bool,
1013 ) {
1014 if matches!(self.model.cache(), EitherCache::Full(_)) {
1015 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
1016 } else {
1017 NormalCacheManager.set_none_cache(
1018 self,
1019 seqs,
1020 modify_draft_cache,
1021 load_preallocated_cache,
1022 );
1023 }
1024 if reset_non_granular {
1025 self.reset_non_granular_state()
1026 }
1027 }
1028 fn cache(&self) -> &EitherCache {
1029 self.model.cache()
1030 }
1031}
1032
1033impl MetadataMixin for NormalPipeline {
1034 fn device(&self) -> Device {
1035 self.model.device().clone()
1036 }
1037 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
1038 Some(self.tokenizer.clone())
1039 }
1040 fn name(&self) -> String {
1041 self.model_id.clone()
1042 }
1043 fn reset_non_granular_state(&self) {
1044 if let Some(s) = self.non_granular_state.as_ref() {
1045 *self.cache().full().get_scalings_cache() = None;
1046 *get_mut_arcmutex!(s.non_granular_index) = 0;
1047 }
1048 }
1049 fn get_metadata(&self) -> Arc<GeneralMetadata> {
1050 self.metadata.clone()
1051 }
1052 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
1053 Some(&*self.mapper)
1054 }
1055}
1056
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    /// Run one forward pass over batched inputs. `inputs` must downcast to
    /// [`ModelInputs`]; dispatches to the X-LoRA or normal forward path and
    /// returns either raw or causal-generation logits.
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        // A cache engine and PagedAttention metadata must be present together
        // or absent together; a mismatch indicates a scheduler misconfiguration.
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model.is_xlora() {
            false => {
                // Pair the engine's KV cache with the per-step metadata.
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            // X-LoRA path: the `*_full` variants fall back to the regular
            // inputs when not provided.
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }

    /// Sample next tokens for each sequence from the given logits.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}
1136
1137impl AnyMoePipelineMixin for NormalPipeline {
1138 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1139 self.model.finish_training(gate_model_id)
1140 }
1141 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1142 self.model.get_vars()
1143 }
1144 fn amoe_base_model_trainable_params(&self) -> usize {
1145 self.model.trainable_params()
1146 }
1147 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1148 self.model.take_cached_gating_outputs()
1149 }
1150 fn amoe_create_layers(
1151 &mut self,
1152 model_ids: Vec<String>,
1153 token: &TokenSource,
1154 revision: Option<String>,
1155 match_regex: &str,
1156 config: crate::amoe::AnyMoeConfig,
1157 dtype: candle_core::DType,
1158 dev: &Device,
1159 (prefix, mlp): (String, String),
1160 layers: Vec<usize>,
1161 expert_type: AnyMoeExpertType,
1162 silent: bool,
1163 gate_model_id: Option<String>,
1164 ) -> candle_core::Result<()> {
1165 let mut vbs = Vec::new();
1166 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1168 for model_id in model_ids {
1169 let model_id_str = &model_id;
1170 let model_id = Path::new(&model_id);
1171
1172 let api = {
1173 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1174 let mut api = ApiBuilder::from_cache(cache)
1175 .with_progress(!silent)
1176 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1177 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1178 api = api.with_cache_dir(x.into());
1179 }
1180 api.build().map_err(candle_core::Error::msg)?
1181 };
1182 let revision = revision.clone().unwrap_or("main".to_string());
1183 let api = api.repo(Repo::with_revision(
1184 model_id_str.clone(),
1185 RepoType::Model,
1186 revision.clone(),
1187 ));
1188
1189 let mut filenames = vec![];
1190 for rfilename in
1191 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1192 {
1193 filenames.push(api_get_file!(api, &rfilename, model_id));
1194 }
1195
1196 let regex = regex.clone();
1197 let match_regex_clone = match_regex.to_string();
1198 let layers_clone = layers.clone();
1199 let vb = from_mmaped_safetensors(
1200 filenames,
1201 vec![],
1202 Some(dtype),
1203 dev,
1204 vec![None],
1205 silent,
1206 None,
1207 move |key| {
1208 if regex.is_match(&key) {
1209 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1212 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1213 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1214 .parse::<usize>()
1215 .unwrap();
1216 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1217 } else {
1218 false
1219 }
1220 },
1221 Arc::new(|_| DeviceForLoadTensor::Base),
1222 )?;
1223 vbs.push(vb);
1224 }
1225
1226 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1227 let model_id_str = &gate_model_id;
1228 let model_id = Path::new(&gate_model_id);
1229
1230 let api = {
1231 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1232 let mut api = ApiBuilder::from_cache(cache)
1233 .with_progress(!silent)
1234 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1235 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1236 api = api.with_cache_dir(x.into());
1237 }
1238 api.build().map_err(candle_core::Error::msg)?
1239 };
1240 let revision = revision.clone().unwrap_or("main".to_string());
1241 let api = api.repo(Repo::with_revision(
1242 model_id_str.clone(),
1243 RepoType::Model,
1244 revision.clone(),
1245 ));
1246
1247 let mut gate_filenames = vec![];
1248 for rfilename in
1249 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1250 {
1251 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1252 }
1253 assert_eq!(
1254 gate_filenames.len(),
1255 1,
1256 "Gate model ID must contain only one .safetensors file"
1257 );
1258
1259 let vb = from_mmaped_safetensors(
1260 gate_filenames.clone(),
1261 vec![],
1262 Some(dtype),
1263 dev,
1264 vec![None],
1265 silent,
1266 None,
1267 |_| true,
1268 Arc::new(|_| DeviceForLoadTensor::Base),
1269 )?;
1270 info!(
1271 "Loaded gating layers from `{}`",
1272 gate_filenames[0].display()
1273 );
1274 Some(vb)
1275 } else {
1276 None
1277 };
1278
1279 self.model.create_anymoe_layers(
1280 vbs.clone(),
1281 config.clone(),
1282 (prefix.clone(), mlp.clone()),
1283 layers.clone(),
1284 expert_type.clone(),
1285 gate_vb.clone(),
1286 )?;
1287
1288 Ok(())
1289 }
1290 fn amoe_supported(&self) -> bool {
1291 self.model.amoe_supported()
1292 }
1293}