1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7 TokenSource,
8};
9use super::{
10 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
15 LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
16 Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::num::{NonZero, NonZeroUsize};
55use std::path::{Path, PathBuf};
56use std::str::FromStr;
57use std::sync::{Arc, RwLock};
58use std::time::Instant;
59use std::{env, fs};
60use tokenizers::Tokenizer;
61use tokio::sync::Mutex;
62use tracing::{info, warn};
63
/// A loaded text model together with everything needed to run it.
///
/// Instances are produced by `NormalLoader::load_model_from_path`.
pub struct NormalPipeline {
    /// The loaded model; forward passes and (re-)quantization are delegated to it.
    model: Box<dyn NormalModel + Send + Sync>,
    /// Tokenizer loaded from the model's tokenizer file.
    tokenizer: Arc<Tokenizer>,
    /// Forwarded to `xlora_forward` and mirrored in the metadata.
    no_kv_cache: bool,
    /// Chat template resolved at load time.
    chat_template: Arc<ChatTemplate>,
    /// Present only when the loader was given a `tgt_non_granular_index`.
    non_granular_state: Option<NonGranularState>,
    /// Model ID; also reported as the pipeline name.
    model_id: String,
    /// Shared metadata handed out by `get_metadata`.
    metadata: Arc<GeneralMetadata>,
    // The fields from `topology` through `imatrix` are retained so that
    // `re_isq_model` can re-run quantization after loading.
    topology: Option<Topology>,
    silent: bool,
    organization: IsqOrganization,
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    /// Raw contents of the model's config file.
    config: String,
    imatrix: Option<PathBuf>,
    /// Device mapper exposed through `device_mapper`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
82
/// Loader for a text model; constructed through `NormalLoaderBuilder::build`.
pub struct NormalLoader {
    /// Architecture-specific loader selected in `NormalLoaderBuilder::build`.
    inner: Box<dyn NormalModelLoader>,
    /// Model ID; returned by `get_id` and copied into the pipeline.
    model_id: String,
    /// Loader configuration (chunk size, topology, ISQ/UQFF options, ...).
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    /// Whether this is a plain model or an adapter (LoRA / X-LoRA) model.
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    /// When set, the resulting pipeline gets a `NonGranularState` with this target index.
    tgt_non_granular_index: Option<usize>,
    /// Recorded by `load_model_from_hf`; `None` until then.
    token_source: RwLock<Option<TokenSource>>,
    /// Recorded by `load_model_from_hf`; `None` until then.
    revision: RwLock<Option<String>>,
    /// UQFF paths resolved by `load_model_from_hf` from `config.from_uqff`.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    jinja_explicit: Option<String>,
    /// Optional Hugging Face cache directory used to initialise the global cache.
    hf_cache_path: Option<PathBuf>,
}
102
/// Builder for `NormalLoader`.
///
/// Start with `NormalLoaderBuilder::new`, optionally configure adapters or the
/// Hugging Face cache path, then call `build`.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    /// May be left unset when an X-LoRA ordering supplies a base model ID.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    /// Set to `ModelKind::Normal` by `new`; changed by `with_xlora` / `with_lora`.
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
119
/// Configuration specific to loading a normal (text) model.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    /// Prompt chunk size; the loader falls back to `DEFAULT_PROMPT_CHUNK_SIZE` when `None`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Optional topology passed to device mapping and quantization.
    pub topology: Option<Topology>,
    /// ISQ organization (e.g. default vs. MoE-experts-only).
    pub organization: IsqOrganization,
    /// When set, passed to `quantize` as the UQFF output location.
    pub write_uqff: Option<PathBuf>,
    /// UQFF files to load quantized weights from instead of applying ISQ.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Path to an imatrix file; mutually exclusive with `calibration_file`.
    pub imatrix: Option<PathBuf>,
    /// Text file used to collect an imatrix at load time; mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// NOTE(review): not read by the code visible in this file section; the loader
    /// uses its own `hf_cache_path` field — confirm where this one is consumed.
    pub hf_cache_path: Option<PathBuf>,
}
132
133impl NormalLoaderBuilder {
134 pub fn new(
135 config: NormalSpecificConfig,
136 chat_template: Option<String>,
137 tokenizer_json: Option<String>,
138 model_id: Option<String>,
139 no_kv_cache: bool,
140 jinja_explicit: Option<String>,
141 ) -> Self {
142 Self {
143 config,
144 chat_template,
145 tokenizer_json,
146 model_id,
147 kind: ModelKind::Normal,
148 jinja_explicit,
149 no_kv_cache,
150 ..Default::default()
151 }
152 }
153
154 fn with_adapter(
155 mut self,
156 xlora_model_id: String,
157 xlora_order: Ordering,
158 no_kv_cache: bool,
159 tgt_non_granular_index: Option<usize>,
160 ) -> Self {
161 self.xlora_model_id = Some(xlora_model_id);
162 self.xlora_order = Some(xlora_order);
163 self.no_kv_cache = no_kv_cache;
164 self.tgt_non_granular_index = tgt_non_granular_index;
165 self.model_id = if let Some(id) = self.model_id {
166 Some(id)
167 } else {
168 info!(
169 "Using adapter base model ID: `{}`",
170 self.xlora_order.as_ref().unwrap().base_model_id
171 );
172 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
173 };
174 self
175 }
176
177 pub fn with_xlora(
178 mut self,
179 xlora_model_id: String,
180 xlora_order: Ordering,
181 no_kv_cache: bool,
182 tgt_non_granular_index: Option<usize>,
183 ) -> Self {
184 self.kind = ModelKind::Adapter {
185 adapter: AdapterKind::XLora,
186 };
187 self.with_adapter(
188 xlora_model_id,
189 xlora_order,
190 no_kv_cache,
191 tgt_non_granular_index,
192 )
193 }
194
195 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
196 self.kind = ModelKind::Adapter {
197 adapter: AdapterKind::Lora,
198 };
199 self.lora_adapter_ids = Some(lora_adapter_ids);
200 self
201 }
202
203 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
204 self.hf_cache_path = Some(hf_cache_path);
205 self
206 }
207
208 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
211 let loader: Box<dyn NormalModelLoader> = match loader_tp {
212 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
213 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
214 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
215 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
216 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
217 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
218 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
219 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
220 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
221 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
222 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
223 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
224 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
225 Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
226 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
227 Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
228 None => Box::new(AutoNormalLoader),
229 };
230 Ok(Box::new(NormalLoader {
231 inner: loader,
232 model_id: self.model_id.unwrap(),
233 config: self.config,
234 xlora_model_id: self.xlora_model_id,
235 lora_adapter_ids: self.lora_adapter_ids,
236 kind: self.kind,
237 xlora_order: self.xlora_order,
238 no_kv_cache: self.no_kv_cache,
239 chat_template: self.chat_template,
240 tokenizer_json: self.tokenizer_json,
241 tgt_non_granular_index: self.tgt_non_granular_index,
242 jinja_explicit: self.jinja_explicit,
243 token_source: RwLock::new(None),
244 revision: RwLock::new(None),
245 from_uqff: RwLock::new(None),
246 hf_cache_path: self.hf_cache_path,
247 }))
248 }
249}
250
251impl Loader for NormalLoader {
252 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
253 fn load_model_from_hf(
254 &self,
255 revision: Option<String>,
256 token_source: TokenSource,
257 dtype: &dyn TryIntoDType,
258 device: &Device,
259 silent: bool,
260 mapper: DeviceMapSetting,
261 in_situ_quant: Option<IsqType>,
262 paged_attn_config: Option<PagedAttentionConfig>,
263 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
264 let cache = self
265 .hf_cache_path
266 .clone()
267 .map(Cache::new)
268 .unwrap_or_default();
269 GLOBAL_HF_CACHE.get_or_init(|| cache);
270
271 let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
272 LocalModelPaths,
273 &token_source,
274 revision.clone(),
275 self,
276 None,
277 None,
278 silent,
279 self.config.from_uqff.is_some()
280 );
281 if let Some(from_uqff) = self.config.from_uqff.clone() {
282 *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
283 }
284 *self
285 .token_source
286 .write()
287 .expect("Failed to write to token source") = Some(token_source);
288 *self.revision.write().expect("Failed to write to revision") = revision;
289 self.load_model_from_path(
290 &paths?,
291 dtype,
292 device,
293 silent,
294 mapper,
295 in_situ_quant,
296 paged_attn_config,
297 )
298 }
299
300 #[allow(clippy::type_complexity, clippy::too_many_arguments)]
301 fn load_model_from_path(
302 &self,
303 paths: &Box<dyn ModelPaths>,
304 dtype: &dyn TryIntoDType,
305 device: &Device,
306 silent: bool,
307 mut mapper: DeviceMapSetting,
308 mut in_situ_quant: Option<IsqType>,
309 mut paged_attn_config: Option<PagedAttentionConfig>,
310 ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
311 let config = std::fs::read_to_string(paths.get_config_filename())?;
312
313 if !self.inner.supports_paged_attention(&config)? {
314 paged_attn_config = None;
315 }
316
317 let prompt_chunksize = self
319 .config
320 .prompt_chunksize
321 .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
322 .get();
323
324 info!("Prompt chunk size is {prompt_chunksize}.",);
325
326 let use_nccl = mistralrs_quant::distributed::use_nccl();
327
328 let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
329 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
330 let WorkerTransferData::Init { id: _, worker_rank } = payload;
331 vec![candle_core::Device::new_cuda(worker_rank + 1)?]
332 } else if use_nccl {
333 vec![candle_core::Device::new_cuda(0)?]
334 } else {
335 device_map::get_all_similar_devices(device)?
336 };
337 let device = if use_nccl || cfg!(feature = "ring") {
338 available_devices[0].clone()
339 } else {
340 device.clone()
341 };
342
343 if use_nccl || cfg!(feature = "ring") {
345 mapper = DeviceMapSetting::DummyNccl {
346 nm_device: available_devices[0].clone(),
347 };
348 } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
349 let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
351
352 if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
354 in_situ_quant = None;
355 }
356
357 let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
360 if let Some(serialized) = &*self.from_uqff.read().unwrap() {
361 let weight_pack_factor = {
362 let ser_artifacts = unsafe {
363 candle_core::safetensors::MmapedSafetensors::multi(serialized)?
364 };
365 let mut total_pack_factors = 0;
366 let total_tensors = ser_artifacts.tensors().len();
367 for (_, artifact) in ser_artifacts.tensors() {
368 let artifact = artifact.data();
369 let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
371 let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
372 {
373 QuantizedSerdeType::Hqq => {
374 HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
375 .pack_factor(dtype)
376 }
377 QuantizedSerdeType::Gguf => {
378 GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
379 .pack_factor(dtype)
380 }
381 QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
382 QuantizedSerdeType::Unquant => 1,
383 QuantizedSerdeType::Afq => {
384 AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
385 .pack_factor(dtype)
386 }
387 };
388 total_pack_factors += pack_factor;
389 }
390
391 total_pack_factors / total_tensors
392 };
393
394 let layer_sizes_in_bytes =
395 self.inner
396 .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
397 let non_mapped_size_in_bytes =
398 self.inner
399 .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
400 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
401 (
402 layer_sizes_in_bytes,
403 non_mapped_size_in_bytes,
404 layer_sizes_sum + non_mapped_size_in_bytes,
405 )
406 } else if let Some(isq) = in_situ_quant {
407 let weight_pack_factor = isq.pack_factor(dtype);
408 let layer_sizes_in_bytes =
409 self.inner
410 .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
411 let non_mapped_size_in_bytes =
412 self.inner
413 .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
414 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
415 (
416 layer_sizes_in_bytes,
417 non_mapped_size_in_bytes,
418 layer_sizes_sum + non_mapped_size_in_bytes,
419 )
420 } else {
421 let weight_pack_factor =
423 QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
424 let layer_sizes_in_bytes =
425 self.inner
426 .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
427 let non_mapped_size_in_bytes =
428 self.inner
429 .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
430 let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
431 (
432 layer_sizes_in_bytes,
433 non_mapped_size_in_bytes,
434 layer_sizes_sum + non_mapped_size_in_bytes,
435 )
436 };
437
438 let new = auto_device_map::get_device_layers(
439 &*self.inner,
440 &config,
441 self.inner.num_layers(&config)?,
442 layer_sizes_in_bytes,
443 non_mapped_size_in_bytes,
444 total_model_size_in_bytes,
445 &available_devices,
446 dtype,
447 ¶ms,
448 prompt_chunksize,
449 paged_attn_config.as_ref(),
450 )?;
451 mapper = DeviceMapSetting::Map(new);
452 }
453
454 let pipeline_mapper = mapper.into_mapper(
455 self.inner.num_layers(&config)?,
456 &device,
457 self.config.topology.as_ref(),
458 )?;
459 let mapper = mapper.into_mapper(
460 self.inner.num_layers(&config)?,
461 &device,
462 self.config.topology.as_ref(),
463 )?;
464 let mut layer_devices = Vec::new();
465 for layer in 0..self.inner.num_layers(&config)? {
466 let device = mapper.device_for(layer, false).cloned();
467 layer_devices.push(device);
468 }
469 let dtype = mapper.get_min_dtype(dtype)?;
470
471 let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
474 if mapping_uses_cpu {
475 warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
476 paged_attn_config = None;
477 }
478
479 info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
480 if crate::using_flash_attn() {
481 once_log_info("FlashAttention is enabled.");
482 }
483
484 let mut loading_isq = if self.config.imatrix.is_none()
486 && self.config.calibration_file.is_none()
487 && !device.is_cuda()
488 && self.config.write_uqff.is_none()
489 && in_situ_quant.is_some()
490 {
491 let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
492 {
493 self.inner.immediate_isq_predicates_moqe(&config)?
494 } else {
495 self.inner.immediate_isq_predicates(&config)?
496 };
497 info!("Applying ISQ to {in_situ_quant:?}");
498 if predicates.is_empty() {
499 warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
500 }
501 mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
502 false
503 } else {
504 in_situ_quant.is_some()
505 };
506
507 if let Some(ref topology) = self.config.topology {
508 loading_isq |= topology
509 .0
510 .iter()
511 .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
512 }
513
514 if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
515 anyhow::bail!(
516 "`imatrix` and `calibration_file` were both specified, this is not allowed."
517 );
518 }
519
520 let load_device = if !loading_isq || self.config.calibration_file.is_some() {
522 loading_isq = false;
523 device.clone()
524 } else {
525 Device::Cpu
526 };
527
528 let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
529
530 let attention_mechanism = if paged_attn_config.is_some() {
531 AttentionImplementation::PagedAttention
532 } else {
533 AttentionImplementation::Eager
534 };
535
536 let multi_progress = Arc::new(MultiProgress::new());
537
538 let mut model = if use_nccl || cfg!(feature = "ring") {
539 let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
540 dtype,
541 &device,
542 &available_devices,
543 silent,
544 &config,
545 loading_isq,
546 self.config.from_uqff.is_some(),
547 self.config.organization,
548 &*self.inner,
549 paths.as_ref(),
550 )?;
551
552 match self.kind {
554 ModelKind::Normal => normal_model_loader_sharded!(
555 sharded_vb,
556 config,
557 self.inner,
558 mapper,
559 loading_isq,
560 device.clone(),
561 attention_mechanism,
562 multi_progress.clone(),
563 ),
564 ModelKind::Adapter {
565 adapter: AdapterKind::XLora,
566 } => xlora_model_loader!(
567 paths,
568 Some(dtype),
569 &load_device,
570 layer_devices.clone(),
571 config,
572 self.inner,
573 silent,
574 mapper,
575 loading_isq,
576 device.clone(),
577 multi_progress.clone(),
578 ),
579 ModelKind::Adapter {
580 adapter: AdapterKind::Lora,
581 } => lora_model_loader!(
582 paths,
583 Some(dtype),
584 &load_device,
585 layer_devices.clone(),
586 config,
587 self.inner,
588 silent,
589 mapper,
590 loading_isq,
591 self.config.from_uqff.is_some(),
592 device.clone(),
593 attention_mechanism,
594 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
595 multi_progress.clone(),
596 ),
597 _ => unreachable!(),
598 }
599 } else {
600 match self.kind {
601 ModelKind::Normal => normal_model_loader!(
602 paths,
603 Some(dtype),
604 &load_device,
605 layer_devices.clone(),
606 config,
607 self.inner,
608 silent,
609 mapper,
610 loading_isq,
611 self.config.from_uqff.is_some(),
612 device.clone(),
613 attention_mechanism,
614 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
615 multi_progress.clone(),
616 ),
617 ModelKind::Adapter {
618 adapter: AdapterKind::XLora,
619 } => xlora_model_loader!(
620 paths,
621 Some(dtype),
622 &load_device,
623 layer_devices.clone(),
624 config,
625 self.inner,
626 silent,
627 mapper,
628 loading_isq,
629 device.clone(),
630 multi_progress.clone(),
631 ),
632 ModelKind::Adapter {
633 adapter: AdapterKind::Lora,
634 } => lora_model_loader!(
635 paths,
636 Some(dtype),
637 &load_device,
638 layer_devices.clone(),
639 config,
640 self.inner,
641 silent,
642 mapper,
643 loading_isq,
644 self.config.from_uqff.is_some(),
645 device.clone(),
646 attention_mechanism,
647 matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
648 multi_progress.clone(),
649 ),
650 _ => unreachable!(),
651 }
652 };
653
654 let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
655 let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().and_then(|f| {
656 match serde_json::from_str::<GenerationConfig>(&fs::read_to_string(f).unwrap()) {
657 Ok(conf) => Some(conf),
658 Err(e) => {
659 warn!("Failed to parse generation_config.json: {}", e);
660 None
661 }
662 }
663 });
664
665 let chat_template_explicit = paths
666 .get_chat_template_explicit()
667 .as_ref()
668 .map(|x| x.to_string_lossy().to_string());
669 let chat_template = get_chat_template(
670 paths,
671 self.jinja_explicit.as_ref(),
672 chat_template_explicit.as_ref(),
673 self.chat_template.as_ref(),
674 None,
675 );
676
677 if let Some(calibration_file) = &self.config.calibration_file {
678 let calibration_data = std::fs::read_to_string(calibration_file)?;
679 let tokens = tokenizer
681 .encode_fast(calibration_data, false)
682 .map_err(anyhow::Error::msg)?
683 .get_ids()
684 .to_vec();
685 info!(
686 "Collecting imatrix from calibration file `{}` of {} tokens.",
687 calibration_file.display(),
688 tokens.len()
689 );
690 let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
691 let bos_tok_id = tokenizer
692 .token_to_id(&bos_toks[0])
693 .expect("Somehow the bos token is not present.");
694
695 match self.config.organization {
696 IsqOrganization::Default => model.begin_track_stats()?,
697 IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
698 }
699
700 const CHUNK_SIZE: usize = 1024;
701 let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
702 let start = Instant::now();
703 for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
704 let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
705 let chunk_len = chunk.len();
706
707 let start = Instant::now();
708 let inputs = make_prompt_chunk(
709 0,
710 vec![&chunk],
711 &[0],
712 &load_device,
713 None,
714 false,
715 None,
716 Some(pipeline_mapper.as_ref()),
717 )?;
718
719 model.forward(
720 &inputs.input.to_device(model.device())?,
721 &inputs.positions,
722 inputs.context_lens.clone(),
723 inputs.position_ids.clone(),
724 None,
725 &inputs.flash_meta.clone(),
726 )?;
727
728 match model.cache_mut() {
729 EitherCache::Full(full) => {
730 for layer in &mut *full.lock() {
731 *layer = None
732 }
733 }
734 EitherCache::Normal(normal) => {
735 for layer in &mut *normal.lock().unwrap().0 {
736 layer.reset();
737 }
738 }
739 }
740
741 let end = Instant::now();
742 info!(
743 "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
744 i + 1,
745 end.duration_since(start).as_secs_f32()
746 );
747 }
748 load_device.synchronize()?;
749 let end = Instant::now();
750 info!(
751 "Finished collecting imatrix in {:.2}s",
752 end.duration_since(start).as_secs_f32()
753 );
754 }
755
756 if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
758 let imatrix_source = match (
759 self.config.imatrix.as_ref(),
760 self.config.calibration_file.is_some(),
761 ) {
762 (None, false) => None,
763 (Some(file), false) => Some(ImatrixDataSource::File(file)),
764 (None, true) => Some(ImatrixDataSource::Collected),
765 (Some(_), true) => unreachable!(),
766 };
767
768 info!("Applying ISQ to all ranks.");
769
770 let multi_progress = Arc::new(MultiProgress::new());
771
772 model.quantize(
773 in_situ_quant,
774 model.device().clone(),
775 self.config.topology.as_ref(),
776 silent,
777 imatrix_source,
778 self.config.organization,
779 self.config.write_uqff.as_ref(),
780 UqffFullSer {
781 tokenizer: &tokenizer,
782 template_filename: paths.get_template_filename(),
783 generation_config: paths.get_gen_conf_filename(),
784 config: config.clone(),
785 processor_filename: &None,
786 preprocessor_filename: &None,
787 },
788 multi_progress.clone(),
789 )?;
790 } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
791 model.load_from_artifacts(
792 device.clone(),
793 self.config.topology.as_ref(),
794 silent,
795 from_uqff,
796 )?;
797 }
798
799 let paged_attn_config = if matches!(
800 self.kind,
801 ModelKind::Adapter {
802 adapter: AdapterKind::XLora
803 }
804 ) {
805 warn!(
806 "Adapter parallel_models do not currently support PagedAttention, running without"
807 );
808 None
809 } else {
810 paged_attn_config
811 };
812
813 let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
814 let cache_config = calculate_cache_config(
815 paged_attn_config.mem_gpu,
816 paged_attn_config.mem_cpu,
817 paged_attn_config.block_size,
818 dtype,
819 paged_attn_config.cache_type,
820 model.config(),
821 &device,
822 &pipeline_mapper
823 .get_unique_devices()
824 .into_iter()
825 .map(Some)
826 .collect::<Vec<_>>(),
827 silent,
828 )?;
829
830 let mut layer_devices = Vec::new();
831 for layer in 0..self.inner.num_layers(&config)? {
832 let device = model.get_layers().1.device_for(layer, false).cloned();
833 layer_devices.push(device);
834 }
835 let cache_engine = CacheEngine::new(
836 model.config(),
837 &cache_config,
838 dtype,
839 model.device(),
840 layer_devices.clone(),
841 )?;
842
843 (Some(cache_config), Some(cache_engine))
844 } else {
845 (None, None)
846 };
847
848 let max_seq_len = model.max_seq_len();
849 let llg_factory = build_llg_factory(tokenizer.clone())?;
850 let num_hidden_layers = match model.cache() {
851 EitherCache::Full(full) => full.lock().len(),
852 EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
853 };
854 let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
855 let sliding_window = model.config().sliding_window;
856 let model_metadata = Arc::new(model.config().clone());
857
858 Ok(Arc::new(Mutex::new(NormalPipeline {
859 model,
860 tokenizer: tokenizer.into(),
861 no_kv_cache: self.no_kv_cache,
862 chat_template: Arc::new(chat_template),
863 non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
864 NonGranularState {
865 non_granular_index: Arc::new(Mutex::new(0)),
866 tgt_non_granular_index,
867 }
868 }),
869 model_id: self.model_id.clone(),
870 metadata: Arc::new(GeneralMetadata {
871 max_seq_len,
872 llg_factory: Some(llg_factory),
873 no_kv_cache: self.no_kv_cache,
874 no_prefix_cache: is_xlora,
875 num_hidden_layers,
876 eos_tok: eos,
877 kind: self.kind.clone(),
878 is_xlora,
879 activation_dtype: dtype,
880 sliding_window,
881 cache_config,
882 cache_engine,
883 prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
884 model_metadata: Some(model_metadata),
885 modalities: Modalities {
886 input: vec![SupportedModality::Text],
887 output: vec![SupportedModality::Text],
888 },
889 }),
890 topology: self.config.topology.clone(),
891 silent,
892 organization: self.config.organization,
893 template_filename: paths.get_template_filename().clone(),
894 generation_config: paths.get_gen_conf_filename().cloned(),
895 config,
896 imatrix: self.config.imatrix.clone(),
897 mapper: pipeline_mapper,
898 })))
899 }
900
    /// Returns a copy of the model ID this loader was built with.
    fn get_id(&self) -> String {
        self.model_id.clone()
    }
904
    /// Returns a copy of the model kind (plain or adapter) this loader was built with.
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
908}
909
impl PreProcessingMixin for NormalPipeline {
    /// Returns the chat template resolved at load time; always `Some` for this pipeline.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        Some(self.chat_template.clone())
    }
    /// This pipeline has no input processor config, so this is always `None`.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
918
919impl IsqPipelineMixin for NormalPipeline {
920 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
921 let device = self.device().clone();
922 let multi_progress = Arc::new(MultiProgress::new());
923 self.model.quantize(
924 Some(dtype),
925 device.clone(),
926 self.topology.as_ref(),
927 self.silent,
928 self.imatrix.as_ref().map(ImatrixDataSource::File),
929 self.organization,
930 None,
931 UqffFullSer {
932 tokenizer: &self.tokenizer,
933 template_filename: &self.template_filename,
934 generation_config: self.generation_config.as_ref(),
935 config: self.config.clone(),
936 processor_filename: &None,
937 preprocessor_filename: &None,
938 },
939 multi_progress.clone(),
940 )?;
941 Ok(())
942 }
943}
944
945impl CacheManagerMixin for NormalPipeline {
946 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
947 if matches!(self.model.cache(), EitherCache::Full(_)) {
948 FullCacheManager.clone_in_cache(self, seqs, false)
949 } else {
950 NormalCacheManager.clone_in_cache(self, seqs, false)
951 }
952 }
953 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
954 if matches!(self.model.cache(), EitherCache::Full(_)) {
955 FullCacheManager.clone_out_cache(self, seqs, false)
956 } else {
957 NormalCacheManager.clone_out_cache(self, seqs, false)
958 }
959 }
960 fn set_none_cache(
961 &self,
962 seqs: &mut [&mut Sequence],
963 reset_non_granular: bool,
964 modify_draft_cache: bool,
965 load_preallocated_cache: bool,
966 ) {
967 if matches!(self.model.cache(), EitherCache::Full(_)) {
968 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
969 } else {
970 NormalCacheManager.set_none_cache(
971 self,
972 seqs,
973 modify_draft_cache,
974 load_preallocated_cache,
975 );
976 }
977 if reset_non_granular {
978 self.reset_non_granular_state()
979 }
980 }
981 fn cache(&self) -> &EitherCache {
982 self.model.cache()
983 }
984}
985
impl MetadataMixin for NormalPipeline {
    /// Returns a clone of the device reported by the model.
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    /// Returns the tokenizer; always `Some` for this pipeline.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        Some(self.tokenizer.clone())
    }
    /// Returns the model ID as the pipeline name.
    fn name(&self) -> String {
        self.model_id.clone()
    }
    /// Clears the scalings cache and resets the non-granular index to 0.
    ///
    /// Does nothing when the pipeline has no non-granular state.
    fn reset_non_granular_state(&self) {
        if let Some(s) = self.non_granular_state.as_ref() {
            // NOTE(review): `full()` presumably requires the cache to be the
            // `Full` variant — confirm its behavior for a `Normal` cache.
            *self.cache().full().get_scalings_cache() = None;
            *get_mut_arcmutex!(s.non_granular_index) = 0;
        }
    }
    /// Returns a shared handle to the pipeline metadata.
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    /// Returns the device mapper; always `Some` for this pipeline.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        Some(&*self.mapper)
    }
}
1009
1010#[async_trait::async_trait]
1011impl Pipeline for NormalPipeline {
1012 fn forward_inputs(
1013 &mut self,
1014 inputs: Box<dyn Any>,
1015 return_raw_logits: bool,
1016 ) -> Result<ForwardInputsResult, candle_core::Error> {
1017 let ModelInputs {
1018 input_ids,
1019 input_ids_full,
1020 seqlen_offsets,
1021 seqlen_offsets_full,
1022 context_lens,
1023 position_ids,
1024 paged_attn_meta,
1025 flash_meta,
1026 flash_meta_full,
1027 } = *inputs.downcast().expect("Downcast failed.");
1028 let metadata = self.get_metadata();
1029 let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
1030 (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
1031 (Some(_), None) => {
1032 candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
1034 }
1035 (None, Some(_)) => {
1036 candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
1038 }
1039 (None, None) => None,
1040 };
1041 let logits = match self.model.is_xlora() {
1042 false => {
1043 let paged_attn_meta = paged_attn_meta
1044 .as_ref()
1045 .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));
1046
1047 self.model.forward(
1048 &input_ids,
1049 &seqlen_offsets,
1050 context_lens,
1051 position_ids,
1052 paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
1053 &flash_meta,
1054 )?
1055 }
1056 true => self.model.xlora_forward(
1057 &input_ids,
1058 input_ids_full.as_ref().unwrap_or(&input_ids),
1059 &seqlen_offsets,
1060 seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
1061 self.no_kv_cache,
1062 &self.non_granular_state,
1063 context_lens,
1064 position_ids,
1065 &flash_meta,
1066 flash_meta_full.as_ref().unwrap_or(&flash_meta),
1067 )?,
1068 };
1069 if return_raw_logits {
1070 Ok(ForwardInputsResult::RawLogits { logits })
1071 } else {
1072 Ok(ForwardInputsResult::CausalGeneration { logits })
1073 }
1074 }
1075 async fn sample_causal_gen(
1076 &self,
1077 seqs: &mut [&mut Sequence],
1078 logits: Vec<Tensor>,
1079 prefix_cacher: &mut PrefixCacheManagerV2,
1080 disable_eos_stop: bool,
1081 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
1082 ) -> Result<(), candle_core::Error> {
1083 sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
1084 }
1085 fn category(&self) -> ModelCategory {
1086 ModelCategory::Text
1087 }
1088}
1089
1090impl AnyMoePipelineMixin for NormalPipeline {
1091 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1092 self.model.finish_training(gate_model_id)
1093 }
1094 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1095 self.model.get_vars()
1096 }
1097 fn amoe_base_model_trainable_params(&self) -> usize {
1098 self.model.trainable_params()
1099 }
1100 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1101 self.model.take_cached_gating_outputs()
1102 }
1103 fn amoe_create_layers(
1104 &mut self,
1105 model_ids: Vec<String>,
1106 token: &TokenSource,
1107 revision: Option<String>,
1108 match_regex: &str,
1109 config: crate::amoe::AnyMoeConfig,
1110 dtype: candle_core::DType,
1111 dev: &Device,
1112 (prefix, mlp): (String, String),
1113 layers: Vec<usize>,
1114 expert_type: AnyMoeExpertType,
1115 silent: bool,
1116 gate_model_id: Option<String>,
1117 ) -> candle_core::Result<()> {
1118 let mut vbs = Vec::new();
1119 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1121 for model_id in model_ids {
1122 let model_id_str = &model_id;
1123 let model_id = Path::new(&model_id);
1124
1125 let api = {
1126 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1127 let mut api = ApiBuilder::from_cache(cache)
1128 .with_progress(!silent)
1129 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1130 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1131 api = api.with_cache_dir(x.into());
1132 }
1133 api.build().map_err(candle_core::Error::msg)?
1134 };
1135 let revision = revision.clone().unwrap_or("main".to_string());
1136 let api = api.repo(Repo::with_revision(
1137 model_id_str.clone(),
1138 RepoType::Model,
1139 revision.clone(),
1140 ));
1141
1142 let mut filenames = vec![];
1143 for rfilename in
1144 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1145 {
1146 filenames.push(api_get_file!(api, &rfilename, model_id));
1147 }
1148
1149 let regex = regex.clone();
1150 let match_regex_clone = match_regex.to_string();
1151 let layers_clone = layers.clone();
1152 let vb = from_mmaped_safetensors(
1153 filenames,
1154 vec![],
1155 Some(dtype),
1156 dev,
1157 vec![None],
1158 silent,
1159 None,
1160 move |key| {
1161 if regex.is_match(&key) {
1162 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1165 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1166 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1167 .parse::<usize>()
1168 .unwrap();
1169 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1170 } else {
1171 false
1172 }
1173 },
1174 Arc::new(|_| DeviceForLoadTensor::Base),
1175 )?;
1176 vbs.push(vb);
1177 }
1178
1179 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1180 let model_id_str = &gate_model_id;
1181 let model_id = Path::new(&gate_model_id);
1182
1183 let api = {
1184 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1185 let mut api = ApiBuilder::from_cache(cache)
1186 .with_progress(!silent)
1187 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1188 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1189 api = api.with_cache_dir(x.into());
1190 }
1191 api.build().map_err(candle_core::Error::msg)?
1192 };
1193 let revision = revision.clone().unwrap_or("main".to_string());
1194 let api = api.repo(Repo::with_revision(
1195 model_id_str.clone(),
1196 RepoType::Model,
1197 revision.clone(),
1198 ));
1199
1200 let mut gate_filenames = vec![];
1201 for rfilename in
1202 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1203 {
1204 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1205 }
1206 assert_eq!(
1207 gate_filenames.len(),
1208 1,
1209 "Gate model ID must contain only one .safetensors file"
1210 );
1211
1212 let vb = from_mmaped_safetensors(
1213 gate_filenames.clone(),
1214 vec![],
1215 Some(dtype),
1216 dev,
1217 vec![None],
1218 silent,
1219 None,
1220 |_| true,
1221 Arc::new(|_| DeviceForLoadTensor::Base),
1222 )?;
1223 info!(
1224 "Loaded gating layers from `{}`",
1225 gate_filenames[0].display()
1226 );
1227 Some(vb)
1228 } else {
1229 None
1230 };
1231
1232 self.model.create_anymoe_layers(
1233 vbs.clone(),
1234 config.clone(),
1235 (prefix.clone(), mlp.clone()),
1236 layers.clone(),
1237 expert_type.clone(),
1238 gate_vb.clone(),
1239 )?;
1240
1241 Ok(())
1242 }
1243 fn amoe_supported(&self) -> bool {
1244 self.model.amoe_supported()
1245 }
1246}