1use super::isq::ImatrixDataSource;
2use super::llg::build_llg_factory;
3use super::{
4 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
5 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
6 TokenSource,
7};
8use super::{
9 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
10 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
11};
12use super::{
13 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
14 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
15 Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
16};
17use crate::amoe::AnyMoeExpertType;
18use crate::attention::ATTENTION_CHUNK_SIZE;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::path::{Path, PathBuf};
55use std::str::FromStr;
56use std::sync::{Arc, RwLock};
57use std::time::Instant;
58use std::{env, fs};
59use tokenizers::Tokenizer;
60use tokio::sync::Mutex;
61use tracing::{info, warn};
62
/// A fully loaded text-generation pipeline for a "normal" (plain causal-LM,
/// non-GGUF/vision) model, produced by [`NormalLoader`].
pub struct NormalPipeline {
    // The underlying model implementation (device-mapped and possibly quantized).
    model: Box<dyn NormalModel + Send + Sync>,
    // Shared tokenizer used for encoding prompts and decoding generations.
    tokenizer: Arc<Tokenizer>,
    // When true, the KV cache is disabled (used by X-LoRA's non-granular path).
    no_kv_cache: bool,
    // Chat template applied when formatting conversations.
    chat_template: Arc<ChatTemplate>,
    // X-LoRA non-granular scaling state; `None` unless a target index was set.
    non_granular_state: Option<NonGranularState>,
    // Model ID this pipeline was loaded from (reported by `name()`).
    model_id: String,
    // Metadata shared with the scheduler/engine (seq len, cache config, ...).
    metadata: Arc<GeneralMetadata>,
    // Optional per-layer device/ISQ topology, retained for re-quantization.
    topology: Option<Topology>,
    // Suppress progress output when true.
    silent: bool,
    // How ISQ is applied (all weights vs. MoE experts only).
    organization: IsqOrganization,
    // Paths/strings retained so `re_isq_model` can serialize a full UQFF set.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    // Raw model config JSON (kept for UQFF serialization).
    config: String,
    // Optional imatrix file used as calibration data when re-quantizing.
    imatrix: Option<PathBuf>,
    // Device mapper exposed via `device_mapper()`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
81
/// A [`Loader`] for "normal" text models. Built via [`NormalLoaderBuilder`];
/// loading yields a [`NormalPipeline`].
pub struct NormalLoader {
    // Architecture-specific loader (Mistral, Llama, ...) or the auto-detecting one.
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    // The following are interior-mutable because `load_model_from_hf` takes
    // `&self` but must stash state for the later `load_model_from_path` call
    // (e.g. resolved UQFF paths, the token source and revision used).
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    // Explicit Jinja chat-template override, if any.
    jinja_explicit: Option<String>,
    // Custom Hugging Face cache directory; falls back to the default cache.
    hf_cache_path: Option<PathBuf>,
}
101
/// Builder for [`NormalLoader`]. Configure the model source and optional
/// adapters, then call [`NormalLoaderBuilder::build`].
#[derive(Default)]
pub struct NormalLoaderBuilder {
    // HF model ID or local path; may be derived from an adapter ordering.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    // Literal chat template (or path) overriding the repo's template.
    chat_template: Option<String>,
    // Path to a `tokenizer.json` overriding the repo's tokenizer.
    tokenizer_json: Option<String>,
    // X-LoRA non-granular target index, if using non-granular scalings.
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
118
/// Options specific to loading "normal" text models.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Per-layer device/ISQ topology override.
    pub topology: Option<Topology>,
    /// How ISQ is applied (all weights vs. MoE experts only).
    pub organization: IsqOrganization,
    /// If set, serialize the quantized model to this UQFF path after loading.
    pub write_uqff: Option<PathBuf>,
    /// If set, load pre-quantized weights from these UQFF artifact files.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Pre-computed importance matrix file for ISQ (mutually exclusive with
    /// `calibration_file`).
    pub imatrix: Option<PathBuf>,
    /// Plain-text calibration corpus; an imatrix is collected from it at load
    /// time (mutually exclusive with `imatrix`).
    pub calibration_file: Option<PathBuf>,
    /// Custom Hugging Face cache directory.
    pub hf_cache_path: Option<PathBuf>,
    /// Matformer slicing configuration file, if using Matformer models.
    pub matformer_config_path: Option<PathBuf>,
    /// Name of the Matformer slice to use from the config.
    pub matformer_slice_name: Option<String>,
}
132
133impl NormalLoaderBuilder {
134 pub fn new(
135 config: NormalSpecificConfig,
136 chat_template: Option<String>,
137 tokenizer_json: Option<String>,
138 model_id: Option<String>,
139 no_kv_cache: bool,
140 jinja_explicit: Option<String>,
141 ) -> Self {
142 Self {
143 config,
144 chat_template,
145 tokenizer_json,
146 model_id,
147 kind: ModelKind::Normal,
148 jinja_explicit,
149 no_kv_cache,
150 ..Default::default()
151 }
152 }
153
154 fn with_adapter(
155 mut self,
156 xlora_model_id: String,
157 xlora_order: Ordering,
158 no_kv_cache: bool,
159 tgt_non_granular_index: Option<usize>,
160 ) -> Self {
161 self.xlora_model_id = Some(xlora_model_id);
162 self.xlora_order = Some(xlora_order);
163 self.no_kv_cache = no_kv_cache;
164 self.tgt_non_granular_index = tgt_non_granular_index;
165 self.model_id = if let Some(id) = self.model_id {
166 Some(id)
167 } else {
168 info!(
169 "Using adapter base model ID: `{}`",
170 self.xlora_order.as_ref().unwrap().base_model_id
171 );
172 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
173 };
174 self
175 }
176
177 pub fn with_xlora(
178 mut self,
179 xlora_model_id: String,
180 xlora_order: Ordering,
181 no_kv_cache: bool,
182 tgt_non_granular_index: Option<usize>,
183 ) -> Self {
184 self.kind = ModelKind::Adapter {
185 adapter: AdapterKind::XLora,
186 };
187 self.with_adapter(
188 xlora_model_id,
189 xlora_order,
190 no_kv_cache,
191 tgt_non_granular_index,
192 )
193 }
194
195 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
196 self.kind = ModelKind::Adapter {
197 adapter: AdapterKind::Lora,
198 };
199 self.lora_adapter_ids = Some(lora_adapter_ids);
200 self
201 }
202
203 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
204 self.hf_cache_path = Some(hf_cache_path);
205 self
206 }
207
208 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
211 let loader: Box<dyn NormalModelLoader> = match loader_tp {
212 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
213 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
214 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
215 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
216 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
217 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
218 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
219 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
220 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
221 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
222 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
223 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
224 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
225 Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
226 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
227 Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
228 None => Box::new(AutoNormalLoader),
229 };
230 Ok(Box::new(NormalLoader {
231 inner: loader,
232 model_id: self.model_id.unwrap(),
233 config: self.config,
234 xlora_model_id: self.xlora_model_id,
235 lora_adapter_ids: self.lora_adapter_ids,
236 kind: self.kind,
237 xlora_order: self.xlora_order,
238 no_kv_cache: self.no_kv_cache,
239 chat_template: self.chat_template,
240 tokenizer_json: self.tokenizer_json,
241 tgt_non_granular_index: self.tgt_non_granular_index,
242 jinja_explicit: self.jinja_explicit,
243 token_source: RwLock::new(None),
244 revision: RwLock::new(None),
245 from_uqff: RwLock::new(None),
246 hf_cache_path: self.hf_cache_path,
247 }))
248 }
249}
250
impl Loader for NormalLoader {
    /// Resolve all model files from the Hugging Face Hub (or local cache),
    /// stash the token source / revision / resolved UQFF paths for later use,
    /// then delegate to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Install the (possibly custom) HF cache globally; only the first
        // call to `get_or_init` wins.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // Resolve UQFF artifact paths (downloading if needed) before loading.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        // Record the token source and revision so later operations (e.g. AnyMoE
        // layer creation, UQFF serialization) can reuse them.
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load the model from already-resolved file paths and assemble a complete
    /// [`NormalPipeline`].
    ///
    /// High-level stages:
    /// 1. Decide devices (NCCL / daemon worker / auto device map).
    /// 2. Estimate per-layer memory to build an automatic device map if asked.
    /// 3. Load weights (sharded for NCCL/ring, otherwise per-kind macros).
    /// 4. Optionally collect an imatrix from a calibration file.
    /// 5. Apply ISQ or load pre-quantized UQFF artifacts.
    /// 6. Set up PagedAttention cache engine (if enabled) and metadata.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        // Silently disable PagedAttention for architectures that don't support it.
        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        info!("Prompt chunk size is {ATTENTION_CHUNK_SIZE}.");

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        // Device selection: a distributed daemon worker gets exactly one CUDA
        // device keyed by its rank; NCCL uses CUDA 0; otherwise gather all
        // devices similar to the requested one for device mapping.
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        #[cfg(feature = "cuda")]
        for device in &available_devices {
            if let Device::Cuda(dev) = device {
                // SAFETY-adjacent: disables CUDA event tracking on these devices;
                // presumably a perf measure — see candle's CUDA device docs.
                unsafe { dev.disable_event_tracking() };
            }
        }
        let device = if use_nccl || cfg!(feature = "ring") {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        // If multiple GPUs are available and we are not using nccl in tensor parallelism,
        // automatic device mapping estimates per-layer sizes (accounting for any
        // quantization pack factor) and assigns layers to devices.
        if use_nccl || cfg!(feature = "ring") {
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // If the model is already pre-quantized (pack factor != 1), ISQ
            // does not apply on top of it.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // Estimate sizes three ways: from UQFF artifacts (average pack
            // factor over all serialized tensors), from the requested ISQ type,
            // or from the model's own quantization config.
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // The quant type is encoded at a fixed offset in each
                            // UQFF artifact's header.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        // Average pack factor across tensors (integer division).
                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes = self.inner.layer_sizes_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let non_mapped_size_in_bytes = self.inner.non_mapped_size_in_bytes(
                        &config,
                        dtype,
                        weight_pack_factor,
                        None,
                    )?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        // Two mappers are materialized from the same setting: one kept on the
        // pipeline, one consumed during loading below.
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        // Activation dtype: the minimum dtype supported across mapped devices.
        let dtype = mapper.get_min_dtype(dtype)?;

        // PagedAttention has no CPU support; disable it on mixed CPU/GPU maps.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu && paged_attn_config.is_some() {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // "Immediate ISQ" quantizes weights as they are loaded (no separate
        // pass). It is only taken when there is no imatrix/calibration data,
        // nothing to serialize, and — per this condition — the primary device
        // is not CUDA (TODO confirm rationale: presumably the post-load ISQ
        // path is preferred on CUDA).
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
            {
                self.inner.immediate_isq_predicates_moqe(&config)?
            } else {
                self.inner.immediate_isq_predicates(&config)?
            };
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            // Immediate ISQ handles quantization during load; no post-pass needed.
            false
        } else {
            in_situ_quant.is_some()
        };

        // A topology can request ISQ for individual layers even without a
        // global ISQ setting.
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load onto the target device directly unless a post-load ISQ pass is
        // planned (then load on CPU first). Calibration forces on-device load.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        // Optional Matformer slicing: load the config and pick a named slice;
        // without a slice name the model keeps its default slice.
        let matformer_slicing_config = if let Some(matformer_path) =
            &self.config.matformer_config_path
        {
            use crate::matformer::{MatformerConfig, MatformerSliceConfig};
            info!("Loading Matformer config from {:?}", matformer_path);
            let config = Arc::new(MatformerConfig::from_file(matformer_path)?);

            if let Some(slice_name) = &self.config.matformer_slice_name {
                info!("Using Matformer slice: {}", slice_name);
                Some(MatformerSliceConfig::new(slice_name.clone(), config))
            } else {
                warn!("Matformer config loaded but no slice name specified. Models will use their default slice.");
                None
            }
        } else {
            None
        };

        // Load weights. NCCL/ring uses a sharded var builder; otherwise each
        // model kind has its own loader macro.
        let mut model = if use_nccl || cfg!(feature = "ring") {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                    matformer_slicing_config.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        // generation_config.json is best-effort: a parse failure is logged, not
        // fatal. NOTE(review): the inner `fs::read_to_string(f).unwrap()` will
        // still panic if the file exists in paths but cannot be read.
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
            match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
                Ok(conf) => Some(conf),
                Err(e) => {
                    warn!("Failed to parse generation_config.json: {}", e);
                    None
                }
            }
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        // Imatrix collection: run the calibration corpus through the model in
        // chunks (stats tracked inside the model), clearing the KV cache after
        // every chunk so chunks are independent.
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            // NOTE(review): indexes `bos_toks[0]` — panics if the template has
            // no BOS token; presumably calibration requires one.
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                // Each chunk is prefixed with BOS.
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                // Reset the KV cache between chunks.
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Post-load quantization (unless weights came pre-quantized via UQFF).
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                // Both set is rejected by the bail! above.
                (Some(_), true) => unreachable!(),
            };

            info!("Applying ISQ to all ranks.");

            let multi_progress = Arc::new(MultiProgress::new());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        // X-LoRA does not support PagedAttention; fall back to eager cache.
        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!(
                "Adapter parallel_models do not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                paged_attn_config.cache_type,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            // Re-derive per-layer devices from the loaded model's own mapper
            // (may differ from the pre-load estimate).
            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                // X-LoRA cannot use the prefix cache.
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                model_metadata: Some(model_metadata),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    /// The model ID this loader was configured with.
    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    /// The kind of model (normal, or LoRA/X-LoRA adapter).
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
951
952impl PreProcessingMixin for NormalPipeline {
953 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
954 Some(self.chat_template.clone())
955 }
956 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
957 None
958 }
959}
960
impl IsqPipelineMixin for NormalPipeline {
    /// Re-apply in-situ quantization to the loaded model at `dtype`.
    ///
    /// Reuses the pipeline's stored topology/organization and, if present, the
    /// imatrix file captured at load time. Never writes a UQFF artifact here
    /// (the `write_uqff` argument is `None`); the `UqffFullSer` fields are
    /// still required by `quantize`'s signature.
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
        let device = self.device().clone();
        let multi_progress = Arc::new(MultiProgress::new());
        self.model.quantize(
            Some(dtype),
            device.clone(),
            self.topology.as_ref(),
            self.silent,
            self.imatrix.as_ref().map(ImatrixDataSource::File),
            self.organization,
            None,
            UqffFullSer {
                tokenizer: &self.tokenizer,
                template_filename: &self.template_filename,
                generation_config: self.generation_config.as_ref(),
                config: self.config.clone(),
                processor_filename: &None,
                preprocessor_filename: &None,
            },
            multi_progress.clone(),
        )?;
        Ok(())
    }
}
986
987impl CacheManagerMixin for NormalPipeline {
988 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
989 if matches!(self.model.cache(), EitherCache::Full(_)) {
990 FullCacheManager.clone_in_cache(self, seqs, false)
991 } else {
992 NormalCacheManager.clone_in_cache(self, seqs, false)
993 }
994 }
995 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
996 if matches!(self.model.cache(), EitherCache::Full(_)) {
997 FullCacheManager.clone_out_cache(self, seqs, false)
998 } else {
999 NormalCacheManager.clone_out_cache(self, seqs, false)
1000 }
1001 }
1002 fn set_none_cache(
1003 &self,
1004 seqs: &mut [&mut Sequence],
1005 reset_non_granular: bool,
1006 modify_draft_cache: bool,
1007 load_preallocated_cache: bool,
1008 ) {
1009 if matches!(self.model.cache(), EitherCache::Full(_)) {
1010 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
1011 } else {
1012 NormalCacheManager.set_none_cache(
1013 self,
1014 seqs,
1015 modify_draft_cache,
1016 load_preallocated_cache,
1017 );
1018 }
1019 if reset_non_granular {
1020 self.reset_non_granular_state()
1021 }
1022 }
1023 fn cache(&self) -> &EitherCache {
1024 self.model.cache()
1025 }
1026}
1027
impl MetadataMixin for NormalPipeline {
    /// The model's primary device.
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    /// Shared handle to the tokenizer.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    /// The model ID this pipeline was loaded from.
    fn name(&self) -> String {
        self.model_id.clone()
    }
    /// Clear the X-LoRA scalings cache and non-granular index.
    ///
    /// NOTE(review): `self.cache().full()` assumes a full cache when
    /// `non_granular_state` is set — presumably X-LoRA models always use the
    /// full cache; would panic otherwise.
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    /// Metadata shared with the scheduler/engine.
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    /// The device mapper used by this pipeline.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}
1051
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    /// Run one forward step.
    ///
    /// Downcasts `inputs` to [`ModelInputs`], pairs the PagedAttention input
    /// metadata with the pipeline's cache engine (both must be present or both
    /// absent), and dispatches to the plain or X-LoRA forward path.
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        // Cache engine and per-step PagedAttention metadata must agree: a
        // mismatch indicates a scheduler misconfiguration (or internal bug).
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model.is_xlora() {
            false => {
                // Clone the KV cache handle out of the engine for the model call.
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            // X-LoRA path: the `_full` variants fall back to the regular
            // inputs when not provided.
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    /// Sample next tokens for each sequence from `logits` and append them.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    /// This is a text-only pipeline.
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}
1131
impl AnyMoePipelineMixin for NormalPipeline {
    /// Finalizes AnyMoE training on the underlying model, optionally saving
    /// the trained gating layers under `gate_model_id`.
    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
        self.model.finish_training(gate_model_id)
    }
    /// Trainable variables of the AnyMoE layers, grouped per layer.
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        self.model.get_vars()
    }
    /// Number of trainable parameters in the base model.
    fn amoe_base_model_trainable_params(&self) -> usize {
        self.model.trainable_params()
    }
    /// Takes (drains) the gating outputs cached during the last forward pass.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        self.model.take_cached_gating_outputs()
    }
    /// Builds AnyMoE layers from a set of expert models plus an optional gate model.
    ///
    /// For each entry in `model_ids`, downloads (or resolves locally) all
    /// `.safetensors` files, memory-maps them, and keeps only tensors whose key
    /// matches `match_regex` AND whose parsed layer index is in `layers`
    /// (an empty `layers` selects every matching layer). If `gate_model_id` is
    /// given, its repo must contain exactly one `.safetensors` file, which is
    /// loaded in full as the gating weights. Finally the collected var builders
    /// are handed to `create_anymoe_layers` on the model.
    ///
    /// Panics if a matching key does not have the form `...<n>.<match>` with a
    /// parseable layer number `n`, or if the gate repo has != 1 safetensors file.
    fn amoe_create_layers(
        &mut self,
        model_ids: Vec<String>,
        token: &TokenSource,
        revision: Option<String>,
        match_regex: &str,
        config: crate::amoe::AnyMoeConfig,
        dtype: candle_core::DType,
        dev: &Device,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        silent: bool,
        gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        let mut vbs = Vec::new();
        // Precompile the key filter; an invalid pattern is surfaced as a candle error.
        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
        for model_id in model_ids {
            let model_id_str = &model_id;
            let model_id = Path::new(&model_id);

            // Build an HF Hub API client honoring the global cache, the token
            // source, and the optional HF_HUB_CACHE override.
            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            // Fetch every safetensors shard of this expert model.
            let mut filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }

            // Clone captures for the `move` predicate below.
            let regex = regex.clone();
            let match_regex_clone = match_regex.to_string();
            let layers_clone = layers.clone();
            let vb = from_mmaped_safetensors(
                filenames,
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                // Keep only tensors matching the regex whose layer index is selected.
                move |key| {
                    if regex.is_match(&key) {
                        // Extract the layer number from keys shaped like
                        // `...<layer>.<match_regex text>`: take the dot-delimited
                        // segment immediately preceding the literal match text.
                        // NOTE(review): `find(...)-1` underflows if the match is at
                        // the start of the key, and the `unwrap`s panic on keys that
                        // deviate from this shape — TODO confirm this is intended.
                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
                            .parse::<usize>()
                            .unwrap();
                        // Empty `layers` means "all layers".
                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
                    } else {
                        false
                    }
                },
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            vbs.push(vb);
        }

        // Optionally load the gating weights from a dedicated (single-file) repo.
        let gate_vb = if let Some(gate_model_id) = gate_model_id {
            let model_id_str = &gate_model_id;
            let model_id = Path::new(&gate_model_id);

            // Same API client construction as for the expert models above.
            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(!silent)
                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
                    api = api.with_cache_dir(x.into());
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let revision = revision.clone().unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                model_id_str.clone(),
                RepoType::Model,
                revision.clone(),
            ));

            let mut gate_filenames = vec![];
            for rfilename in
                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
            {
                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            assert_eq!(
                gate_filenames.len(),
                1,
                "Gate model ID must contain only one .safetensors file"
            );

            // Load all tensors of the gate file (no key filtering).
            let vb = from_mmaped_safetensors(
                gate_filenames.clone(),
                vec![],
                Some(dtype),
                dev,
                vec![None],
                silent,
                None,
                |_| true,
                Arc::new(|_| DeviceForLoadTensor::Base),
            )?;
            info!(
                "Loaded gating layers from `{}`",
                gate_filenames[0].display()
            );
            Some(vb)
        } else {
            None
        };

        // NOTE(review): the clones here look avoidable since the arguments are
        // owned and not used afterwards — preserved as-is.
        self.model.create_anymoe_layers(
            vbs.clone(),
            config.clone(),
            (prefix.clone(), mlp.clone()),
            layers.clone(),
            expert_type.clone(),
            gate_vb.clone(),
        )?;

        Ok(())
    }
    /// Whether the underlying model supports AnyMoE at all.
    fn amoe_supported(&self) -> bool {
        self.model.amoe_supported()
    }
}